diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aebdfc8..f327c33 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -127,7 +127,9 @@ jobs: with: project: ${{ matrix.pkg.dir }} - uses: julia-actions/julia-processcoverage@v1 + if: matrix.pkg.name == 'Arrow.jl' && matrix.version == '1' && matrix.os == 'macos-latest' && matrix.nthreads == 1 - uses: codecov/codecov-action@v5 + if: matrix.pkg.name == 'Arrow.jl' && matrix.version == '1' && matrix.os == 'macos-latest' && matrix.nthreads == 1 with: files: lcov.info test_monorepo: @@ -168,6 +170,50 @@ jobs: continue-on-error: false run: > julia --color=yes --project=monorepo -e 'using Pkg; Pkg.test("Arrow")' + flight_interop: + name: Arrow Flight interop - Julia 1 - ubuntu-latest + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: '3.11' + - name: Install Flight Python dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pyarrow grpcio grpcio-tools + - uses: julia-actions/setup-julia@v2 + with: + version: '1' + - uses: actions/cache@v5 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-flight-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-flight-${{ env.cache-name }}- + ${{ runner.os }}-flight- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1.6 + with: + project: . + - name: Dev local ArrowTypes for Arrow.jl tests + shell: julia --project=. {0} + run: | + using Pkg + Pkg.develop(PackageSpec(path="src/ArrowTypes")) + - name: Run Arrow Flight interop tests + env: + ARROW_FLIGHT_PYTHON: ${{ env.pythonLocation }}/bin/python + run: > + julia --color=yes --project=test -e 'using Pkg; + Pkg.develop(PackageSpec(path=".")); + Pkg.develop(PackageSpec(path="src/ArrowTypes")); + Pkg.instantiate(); + using Test, Arrow; + include("test/flight.jl")' docs: name: Documentation runs-on: ubuntu-latest diff --git a/.github/workflows/ci_nightly.yml b/.github/workflows/ci_nightly.yml index fb71886..dd869da 100644 --- a/.github/workflows/ci_nightly.yml +++ b/.github/workflows/ci_nightly.yml @@ -64,10 +64,6 @@ jobs: JULIA_NUM_THREADS: ${{ matrix.nthreads }} with: project: ${{ matrix.pkg.dir }} - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v5 - with: - files: lcov.info test_monorepo: name: Monorepo dev - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} @@ -106,3 +102,48 @@ jobs: continue-on-error: false run: > julia --color=yes --project=monorepo -e 'using Pkg; Pkg.test("Arrow")' + flight_interop: + name: Arrow Flight interop - Julia nightly - ubuntu-latest + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: '3.11' + - name: Install Flight Python dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pyarrow grpcio grpcio-tools + - uses: julia-actions/setup-julia@v2 + with: + version: 'nightly' + arch: x64 + - uses: actions/cache@v5 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-flight-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-flight-${{ env.cache-name }}- + ${{ runner.os }}-flight- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1.6 + with: + project: . + - name: Dev local ArrowTypes for Arrow.jl tests + shell: julia --project=. {0} + run: | + using Pkg + Pkg.develop(PackageSpec(path="src/ArrowTypes")) + - name: Run Arrow Flight interop tests + env: + ARROW_FLIGHT_PYTHON: ${{ env.pythonLocation }}/bin/python + run: > + julia --color=yes --project=test -e 'using Pkg; + Pkg.develop(PackageSpec(path=".")); + Pkg.develop(PackageSpec(path="src/ArrowTypes")); + Pkg.instantiate(); + using Test, Arrow; + include("test/flight.jl")' diff --git a/Project.toml b/Project.toml index b87ff79..8150433 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ version = "2.8.1" [deps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1" CodecLz4 = "5ba52731-8f18-5e0d-9241-30f10d1ec561" CodecZstd = "6b39b394-51ab-5f42-8807-6242bab2b4c2" @@ -28,6 +29,9 @@ ConcurrentUtilities = "f0e56b4a-5159-44fe-b623-3e5288b988bb" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" +gRPCClient = "aaca4a50-36af-4a1d-b878-4c443f2061ad" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" @@ -37,6 +41,15 @@ TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53" TranscodingStreams = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +[weakdeps] +gRPCServer = "608c6337-0d7d-447f-bb69-0f5674ee3959" + +[extensions] +ArrowgRPCServerExt = "gRPCServer" + +[sources] +ArrowTypes = { path = "src/ArrowTypes" } + [compat] ArrowTypes = "1.1,2" BitIntegers = "0.2, 0.3" @@ -45,10 +58,14 @@ CodecZstd = "0.7, 0.8" ConcurrentUtilities = "2" DataAPI = "1" EnumX = "1" +JSON3 = "1" +ProtoBuf = "~1.2.1" +gRPCClient = "1" +gRPCServer = "0.1" PooledArrays = "0.5, 1.0" SentinelArrays = "1" StringViews = "1" Tables = "1.1" TimeZones = "1" TranscodingStreams = "0.9.12, 0.10, 0.11" -julia = "1.9" +julia = "1.12" diff --git a/README.md b/README.md index 98bc9fd..7b8f4e6 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,8 @@ The package can be installed by typing in the following in a Julia REPL: julia> using Pkg; Pkg.add("Arrow") ``` +Arrow.jl currently requires Julia `1.12+`. + ## Local Development When developing on Arrow.jl it is recommended that you run the following to ensure that any @@ -49,23 +51,63 @@ changes to ArrowTypes.jl are immediately available to Arrow.jl without requiring julia --project -e 'using Pkg; Pkg.develop(path="src/ArrowTypes")' ``` +Current write-path notes: + * `Arrow.tobuffer` includes a direct single-partition fast path for eligible inputs + * `Arrow.tobuffer(Tables.partitioner(...))` also includes a targeted direct multi-record-batch path for single-column top-level strings and single-column non-missing binary/code-units columns + * `Arrow.write(io, Tables.partitioner(...))` now reuses that same targeted direct multi-record-batch path instead of always going through the legacy `Writer` orchestration + * multi-column partitions, dictionary-encoded top-level columns, map-heavy inputs, and missing-binary partitions retain the existing writer path + ## Format Support This implementation supports the 1.0 version of the specification, including support for: * All primitive data types * All nested data types * Dictionary encodings and messages + * Dictionary-encoded `CategoricalArray` interop, including missing-value roundtrips through `Arrow.Table`, `copy`, and `DataFrame(...; copycols=true)` * Extension types + * Lightweight schema/field metadata overlays via `Arrow.withmetadata(...)` for Tables.jl-compatible sources before serialization + * Base Julia `Enum` logical types via the `JuliaLang.Enum` extension label, with native Julia roundtrips back to the original enum type while `convert=false` and non-Julia consumers still see the primitive storage type + * View-backed Utf8/Binary columns, including recovery from under-reported variadic buffer counts by inferring the required external buffers from valid view elements * Streaming, file, record batch, and replacement and isdelta dictionary messages It currently doesn't include support for: - * Tensors or sparse tensors - * Flight RPC + * Tensor or sparse tensor IPC payload semantics; Arrow.jl now recognizes those message headers explicitly and rejects them with precise errors instead of falling through to a generic unsupported-message path * C data interface + * Writing Run-End Encoded arrays; Arrow.jl now reads REE arrays and exposes them as read-only vectors, but still rejects REE on write paths + +Flight RPC status: + * Experimental `Arrow.Flight` support is available in-tree + * Requires Julia `1.12+` + * Includes generated protocol bindings and complete client constructors for the `FlightService` RPC surface + * Keeps the top-level Flight module shell thin, with exports and generated-protocol setup split out of `src/flight/Flight.jl` + * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut/DoExchange payload generation, plus opt-in `app_metadata` surfacing through `include_app_metadata=true` on `Arrow.Flight.stream(...)` / `Arrow.Flight.table(...)`, explicit batch-wise `app_metadata=...` emission on `Arrow.Flight.flightdata(...)`, `Arrow.Flight.putflightdata!(...)`, and source-based `Arrow.Flight.doexchange(...)`, and a reusable `Arrow.Flight.withappmetadata(...)` wrapper so source-level batch metadata can stay attached without manual keyword threading + * Keeps the Flight IPC conversion layer modular under `src/flight/convert/`, with `src/flight/convert.jl` retained as a thin entrypoint + * Includes client helpers for request headers, binary metadata, handshake token reuse, and TLS configuration via `withheaders`, `withtoken`, and `authenticate` + * Keeps the Flight client implementation modular under `src/flight/client/`, with thin entrypoints at `src/flight/client.jl` and `src/flight/client/rpc_methods.jl` + * Includes a transport-agnostic server core (`Service`, `ServerCallContext`, `ServiceDescriptor`, `MethodDescriptor`) for local Flight method dispatch, path lookup, and handler testing + * Keeps the transport-agnostic server core modular under `src/flight/server/`, with `src/flight/server.jl` retained as a thin entrypoint + * Includes an optional `gRPCServer.jl` package extension that maps `Arrow.Flight.Service` into `gRPCServer.ServiceDescriptor` and registers Flight proto types with the external server package when it is present + * Keeps the optional `gRPCServer.jl` bridge modular under `ext/arrowgrpcserverext/`, with `ext/ArrowgRPCServerExt.jl` retained as a thin entrypoint + * Includes optional live interoperability coverage for `Handshake`, authenticated token propagation, `PollFlightInfo`, and TLS via dedicated Python reference servers + * Includes optional live `pyarrow.flight` interoperability coverage for `ListFlights`, `GetFlightInfo`, `GetSchema`, `DoGet`, `DoPut`, `DoExchange`, `ListActions`, and `DoAction` + * Keeps targeted Flight verification modular under `test/flight/`, with `test/flight.jl` retained as a thin entrypoint for local and CI invocation stability, the client-constructor/protocol-wrapper checks decomposed under `test/flight/client_surface/`, the optional `gRPCServer` extension scenarios decomposed under `test/flight/grpcserver_extension/`, the `pyarrow.flight` interop scenarios decomposed under `test/flight/pyarrow_interop/`, and the transport-agnostic server-core checks decomposed under `test/flight/server_core/` + * Includes `test/flight_grpcserver.jl` as a temporary-environment runner for optional native `gRPCServer` coverage without mutating `test/Project.toml` + * Dedicated CI jobs now exercise the Flight interop suite on stable and nightly Linux; native Julia server transport remains optional/experimental and is not part of the default Flight suite Third-party data formats: * CSV, parquet and avro support via the existing [CSV.jl](https://github.com/JuliaData/CSV.jl), [Parquet.jl](https://github.com/JuliaIO/Parquet.jl) and [Avro.jl](https://github.com/JuliaData/Avro.jl) packages * Other Tables.jl-compatible packages automatically supported ([DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), [JSONTables.jl](https://github.com/JuliaData/JSONTables.jl), [JuliaDB.jl](https://github.com/JuliaData/JuliaDB.jl), [SQLite.jl](https://github.com/JuliaDatabases/SQLite.jl), [MySQL.jl](https://github.com/JuliaDatabases/MySQL.jl), [JDBC.jl](https://github.com/JuliaDatabases/JDBC.jl), [ODBC.jl](https://github.com/JuliaDatabases/ODBC.jl), [XLSX.jl](https://github.com/felipenoris/XLSX.jl), etc.) * No current Julia packages support ORC +Canonical extension highlights: + * `UUID` now writes the canonical `arrow.uuid` extension name by default while retaining reader compatibility with legacy `JuliaLang.UUID` metadata + * `Arrow.TimestampWithOffset{U}` provides a canonical `arrow.timestamp_with_offset` logical type without conflating offset-only semantics with `ZonedDateTime` + * `Arrow.Bool8` provides an explicit opt-in writer/reader surface for the canonical `arrow.bool8` extension without changing the default packed-bit `Bool` path + * `Arrow.JSONText{String}` provides a text-backed logical type for the canonical `arrow.json` extension without parsing payloads during read or write + * `arrow.opaque` now reads as the underlying storage type without warning, and explicit writer metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` + * `Arrow.variantmetadata()`, `Arrow.fixedshapetensormetadata(...)`, and `Arrow.variableshapetensormetadata(...)` generate canonical metadata strings for advanced canonical extensions + * `arrow.fixed_shape_tensor` and `arrow.variable_shape_tensor` are recognized on read as canonical passthrough extensions over their storage types, and Arrow.jl now validates their canonical metadata plus top-level storage shape before accepting them + * `arrow.parquet.variant` is recognized on read as a canonical passthrough extension over its storage type; Arrow.jl currently validates that its canonical metadata is the required empty string, but does not yet implement deeper variant semantics or an automatic writer surface + * Legacy `JuliaLang.ZonedDateTime-UTC` and `JuliaLang.ZonedDateTime` files remain readable for backward compatibility + See the [full documentation](https://arrow.apache.org/julia/) for details on reading and writing arrow data. diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 6e32d07..728dee9 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -16,6 +16,7 @@ # under the License. Manifest.toml +*/Manifest.toml dev/release/apache-rat-*.jar dev/release/filtered_rat.txt dev/release/rat.xml diff --git a/docs/src/manual.md b/docs/src/manual.md index 5a3330f..62aef96 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -87,7 +87,14 @@ In the arrow data format, specific logical types are supported, a list of which * `Date`, `Time`, `Timestamp`, and `Duration` all have natural Julia defintions in `Dates.Date`, `Dates.Time`, `TimeZones.ZonedDateTime`, and `Dates.Period` subtypes, respectively. * `Char` and `Symbol` Julia types are mapped to arrow string types, with additional metadata of the original Julia type; this allows deserializing directly to `Char` and `Symbol` in Julia, while other language implementations will see these columns as just strings -* Similarly to the above, the `UUID` Julia type is mapped to a 128-bit `FixedSizeBinary` arrow type. +* `UUID` is mapped to a 128-bit `FixedSizeBinary` arrow type and now writes the canonical `arrow.uuid` extension name by default while still reading older `JuliaLang.UUID` metadata +* `Arrow.TimestampWithOffset{U}` is the canonical offset-only logical type for `arrow.timestamp_with_offset`; it stores a UTC `Arrow.Timestamp{U,:UTC}` plus `offset_minutes::Int16` and does not imply a timezone-name interpretation +* `Arrow.Bool8` is an explicit opt-in logical type for the canonical `arrow.bool8` extension; it uses `Int8` storage, while plain Julia `Bool` continues to use Arrow's packed-bit boolean layout +* `Arrow.JSONText{String}` is a text-backed logical type for the canonical `arrow.json` extension; Arrow.jl preserves the payload as text and does not parse JSON automatically +* `arrow.opaque` is treated as interoperability metadata over the underlying storage type; explicit metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` when writing +* `Arrow.variantmetadata()`, `Arrow.fixedshapetensormetadata(...)`, and `Arrow.variableshapetensormetadata(...)` generate canonical metadata strings for advanced canonical extensions when writing explicit storage-backed values +* `arrow.fixed_shape_tensor` and `arrow.variable_shape_tensor` are recognized as canonical passthrough extensions on read; Arrow.jl returns their underlying storage types, validates canonical metadata and top-level storage shape, and does not yet implement higher-level semantic interpretation or automatic writer surfaces for them +* `arrow.parquet.variant` is recognized as a canonical passthrough extension on read; Arrow.jl currently validates only the required empty metadata string and does not yet implement deeper variant semantics or an automatic writer surface * `Decimal128` and `Decimal256` have no corresponding builtin Julia types, so they're deserialized using a compatible type definition in Arrow.jl itself: `Arrow.Decimal` @@ -97,10 +104,48 @@ One note on performance: when writing `TimeZones.ZonedDateTime` columns to the a as the column has `ZonedDateTime` elements that all share a common timezone. This ensures the writing process can know "upfront" which timezone will be encoded and is thus much more efficient and performant. +Run-End Encoded arrays are now supported on the read path. Arrow.jl exposes REE +columns as read-only vectors and continues to reject REE on write paths, rather +than attempting a partial or lossy re-encoding. + +Tensor and SparseTensor IPC messages are still unsupported, but Arrow.jl now +recognizes those message headers explicitly and rejects them with precise +errors instead of falling through to a generic unsupported-message failure. + +Similarly, `ArrowTypes.ToArrow` avoids repeated type-promotion work for +homogeneous custom columns even when `ArrowTypes.ArrowType(T)` is abstract, so +write-time conversion does not pay unnecessary overhead once the serialized +element type is stable. + #### Custom types To support writing your custom Julia struct, Arrow.jl utilizes the format's mechanism for "extension types" by allowing the storing of Julia type name and metadata in the field metadata. To "hook in" to this machinery, custom types can utilize the interface methods defined in the `Arrow.ArrowTypes` submodule. For example: +Arrow.jl already uses this mechanism for several Base logical types, including +`nothing`, `Tuple`, `VersionNumber`, and `Complex`, so those values roundtrip as +their original Julia types instead of falling back to plain struct-shaped +`NamedTuple`s. + +Base Julia `@enum` types also work out of the box through the same extension +machinery. Arrow stores the enum as its primitive basetype plus a +`JuliaLang.Enum` extension label that records the qualified Julia type path and +label/value mapping. Native Julia readers reconstruct the enum type, while +`Arrow.Table(...; convert=false)` and non-Julia consumers continue to see the +primitive storage values. + +```julia +using Arrow + +@enum RankingStrategy lexical=1 semantic=2 hybrid=3 + +bytes = read(Arrow.tobuffer((strategy = [lexical, hybrid],))) +typed = Arrow.Table(IOBuffer(bytes)) +raw = Arrow.Table(IOBuffer(bytes); convert=false) + +eltype(typed.strategy) == RankingStrategy +eltype(raw.strategy) == Int32 +``` + ```julia using Arrow @@ -200,6 +245,23 @@ Arrow.jl provides a convenient accessor for this metadata via [`Arrow.getmetadat To attach custom schema/column metadata to Arrow tables at serialization time, see the `metadata` and `colmetadata` keyword arguments to [`Arrow.write`](@ref). +For lightweight overlays on existing Tables.jl sources, Arrow.jl also provides +`Arrow.withmetadata(table_like; metadata=..., colmetadata=...)`. This keeps any +existing schema/field metadata already exposed by the source, overlays new +entries on top, and returns a wrapper that can be passed directly to +[`Arrow.write`](@ref), `Arrow.tobuffer`, or the Flight IPC helpers. + +The Flight IPC helpers also expose batch-wise Flight `app_metadata`. +[`Arrow.Flight.stream`](@ref) and [`Arrow.Flight.table`](@ref) can surface it +with `include_app_metadata=true`, while [`Arrow.Flight.flightdata`](@ref), +[`Arrow.Flight.putflightdata!`](@ref), and source-based +[`Arrow.Flight.doexchange`](@ref) accept `app_metadata=...` to emit one payload +per record batch without dropping down to raw protocol messages. +[`Arrow.Flight.withappmetadata`](@ref) provides the same payload metadata as a +lightweight wrapper around a table or partitioned source, so the metadata can +ride with the source itself instead of being re-specified at every emit call +site. + ## Writing arrow data Ok, so that's a pretty good rundown of *reading* arrow data, but how do you *produce* arrow data? Enter `Arrow.write`. diff --git a/ext/ArrowgRPCServerExt.jl b/ext/ArrowgRPCServerExt.jl new file mode 100644 index 0000000..af16d70 --- /dev/null +++ b/ext/ArrowgRPCServerExt.jl @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module ArrowgRPCServerExt + +using Arrow +using gRPCServer + +include("arrowgrpcserverext/constants.jl") +include("arrowgrpcserverext/context.jl") +include("arrowgrpcserverext/streams.jl") +include("arrowgrpcserverext/handlers.jl") +include("arrowgrpcserverext/descriptor.jl") + +end # module ArrowgRPCServerExt diff --git a/ext/arrowgrpcserverext/constants.jl b/ext/arrowgrpcserverext/constants.jl new file mode 100644 index 0000000..3fd7ab4 --- /dev/null +++ b/ext/arrowgrpcserverext/constants.jl @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const Flight = Arrow.Flight +const STREAM_BUFFER_SIZE = 16 +const GENERATED_TYPE_PREFIX = "Arrow.Flight.Generated." diff --git a/ext/arrowgrpcserverext/context.jl b/ext/arrowgrpcserverext/context.jl new file mode 100644 index 0000000..88598e5 --- /dev/null +++ b/ext/arrowgrpcserverext/context.jl @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _method_type(method::Flight.MethodDescriptor) + if method.request_streaming + return method.response_streaming ? gRPCServer.MethodType.BIDI_STREAMING : + gRPCServer.MethodType.CLIENT_STREAMING + end + return method.response_streaming ? gRPCServer.MethodType.SERVER_STREAMING : + gRPCServer.MethodType.UNARY +end + +function _call_context(context::gRPCServer.ServerContext) + headers = Flight.HeaderPair[ + String(name) => (value isa String ? value : Vector{UInt8}(value)) for + (name, value) in pairs(context.metadata) + ] + peer = string(context.peer.address, ":", context.peer.port) + return Flight.ServerCallContext( + headers=headers, + peer=peer, + secure=(context.peer.certificate !== nothing), + ) +end + +function _proto_type_name(T::Type) + type_name = string(T) + if startswith(type_name, GENERATED_TYPE_PREFIX) + return type_name[(ncodeunits(GENERATED_TYPE_PREFIX) + 1):end] + end + return type_name +end diff --git a/ext/arrowgrpcserverext/descriptor.jl b/ext/arrowgrpcserverext/descriptor.jl new file mode 100644 index 0000000..d3b6459 --- /dev/null +++ b/ext/arrowgrpcserverext/descriptor.jl @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _register_proto_types!(method::Flight.MethodDescriptor) + registry = gRPCServer.get_type_registry() + registry[_proto_type_name(method.request_type)] = method.request_type + registry[_proto_type_name(method.response_type)] = method.response_type + return nothing +end + +function gRPCServer.service_descriptor(service::Flight.Service) + descriptor = Flight.servicedescriptor(service) + methods = Dict{String,gRPCServer.MethodDescriptor}() + for method in descriptor.methods + _register_proto_types!(method) + methods[method.name] = gRPCServer.MethodDescriptor( + method.name, + _method_type(method), + _proto_type_name(method.request_type), + _proto_type_name(method.response_type), + _handler(service, method), + ) + end + return gRPCServer.ServiceDescriptor(descriptor.name, methods, nothing) +end diff --git a/ext/arrowgrpcserverext/handlers.jl b/ext/arrowgrpcserverext/handlers.jl new file mode 100644 index 0000000..028107e --- /dev/null +++ b/ext/arrowgrpcserverext/handlers.jl @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _unary_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, request) -> + Flight.dispatch(service, _call_context(context), method, request) +end + +function _server_streaming_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, request, stream) -> begin + response = Channel{method.response_type}(STREAM_BUFFER_SIZE) + task = @async begin + try + if method.handler_field === :listactions + Flight.listactions(service, _call_context(context), response) + else + Flight.dispatch(service, _call_context(context), method, request, response) + end + finally + close(response) + end + end + try + _drain_response!(stream, response) + _streaming_handler_result(task) + gRPCServer.close!(stream) + finally + istaskdone(task) || wait(task) + end + end +end + +function _client_streaming_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, stream) -> begin + request = Channel{method.request_type}(STREAM_BUFFER_SIZE) + producer = @async begin + try + for message in stream + put!(request, message) + end + finally + close(request) + end + end + task = @async Flight.dispatch(service, _call_context(context), method, request) + try + return fetch(task) + finally + _streaming_handler_result(task, producer) + end + end +end + +function _bidi_streaming_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, stream) -> begin + request = Channel{method.request_type}(STREAM_BUFFER_SIZE) + response = Channel{method.response_type}(STREAM_BUFFER_SIZE) + producer = @async begin + try + for message in stream + put!(request, message) + end + finally + close(request) + end + end + task = @async begin + try + Flight.dispatch(service, _call_context(context), method, request, response) + finally + close(response) + end + end + try + for message in response + gRPCServer.send!(stream, message) + end + _streaming_handler_result(task, producer) + gRPCServer.close!(stream) + finally + istaskdone(task) || wait(task) + isnothing(producer) || (istaskdone(producer) || wait(producer)) + end + return nothing + end +end + +function _handler(service::Flight.Service, method::Flight.MethodDescriptor) + if !method.request_streaming && !method.response_streaming + return _unary_handler(service, method) + elseif !method.request_streaming && method.response_streaming + return _server_streaming_handler(service, method) + elseif method.request_streaming && !method.response_streaming + return _client_streaming_handler(service, method) + end + return _bidi_streaming_handler(service, method) +end diff --git a/ext/arrowgrpcserverext/streams.jl b/ext/arrowgrpcserverext/streams.jl new file mode 100644 index 0000000..22f7925 --- /dev/null +++ b/ext/arrowgrpcserverext/streams.jl @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _drain_response!(stream::gRPCServer.ServerStream, response::Channel) + # gRPCServer falls back to `ServerStream{Any}` when a descriptor only carries + # protobuf type names. Drain generically and let `send!` enforce compatibility. + for message in response + gRPCServer.send!(stream, message) + end + return nothing +end + +function _streaming_handler_result(task::Task, producer::Union{Nothing,Task}=nothing) + if !isnothing(producer) + if istaskfailed(producer) + throw(producer.exception) + end + wait(producer) + end + if istaskfailed(task) + throw(task.exception) + end + wait(task) + return nothing +end diff --git a/src/Arrow.jl b/src/Arrow.jl index 6f3ccdf..68ddda8 100644 --- a/src/Arrow.jl +++ b/src/Arrow.jl @@ -29,9 +29,27 @@ This implementation supports the 1.0 version of the specification, including sup It currently doesn't include support for: * Tensors or sparse tensors - * Flight RPC * C data interface +Flight RPC status: + * Experimental `Arrow.Flight` support is available in-tree + * Requires Julia `1.12+` + * Includes generated protocol bindings and client constructors for the `FlightService` RPC surface + * Keeps the top-level Flight module shell thin, with exports and generated-protocol setup split out of `src/flight/Flight.jl` + * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut payload generation + * Keeps the Flight IPC conversion layer modular under `src/flight/convert/`, with `src/flight/convert.jl` retained as a thin entrypoint + * Includes client helpers for request headers, binary metadata, handshake token reuse, and TLS configuration via `withheaders`, `withtoken`, and `authenticate` + * Keeps the Flight client implementation modular under `src/flight/client/`, with thin entrypoints at `src/flight/client.jl` and `src/flight/client/rpc_methods.jl` + * Includes a transport-agnostic server core (`Service`, `ServerCallContext`, `ServiceDescriptor`, `MethodDescriptor`) for local Flight method dispatch, path lookup, and handler testing + * Keeps the transport-agnostic server core modular under `src/flight/server/`, with `src/flight/server.jl` retained as a thin entrypoint + * Includes an optional `gRPCServer.jl` package extension that maps `Arrow.Flight.Service` into `gRPCServer.ServiceDescriptor` and registers Flight proto types with the external server package when it is present + * Keeps the optional `gRPCServer.jl` bridge modular under `ext/arrowgrpcserverext/`, with `ext/ArrowgRPCServerExt.jl` retained as a thin entrypoint + * Includes optional live interoperability coverage for `Handshake`, authenticated token propagation, `PollFlightInfo`, and TLS via dedicated Python reference servers + * Includes optional live `pyarrow.flight` interoperability coverage for `ListFlights`, `GetFlightInfo`, `GetSchema`, `DoGet`, `DoPut`, `DoExchange`, `ListActions`, and `DoAction` + * Keeps targeted Flight verification modular under `test/flight/`, with `test/flight.jl` retained as a thin entrypoint for local and CI invocation stability, the client-constructor/protocol-wrapper checks decomposed under `test/flight/client_surface/`, the optional `gRPCServer` extension scenarios decomposed under `test/flight/grpcserver_extension/`, the `pyarrow.flight` interop scenarios decomposed under `test/flight/pyarrow_interop/`, and the transport-agnostic server-core checks decomposed under `test/flight/server_core/` + * Includes `test/flight_grpcserver.jl` as a temporary-environment runner for optional native `gRPCServer` coverage without mutating `test/Project.toml` + * Dedicated CI jobs now exercise the Flight interop suite on stable and nightly Linux; native Julia server transport remains optional/experimental and is not part of the default Flight suite + Third-party data formats: * csv and parquet support via the existing [CSV.jl](https://github.com/JuliaData/CSV.jl) and [Parquet.jl](https://github.com/JuliaIO/Parquet.jl) packages * Other [Tables.jl](https://github.com/JuliaData/Tables.jl)-compatible packages automatically supported ([DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), [JSONTables.jl](https://github.com/JuliaData/JSONTables.jl), [JuliaDB.jl](https://github.com/JuliaData/JuliaDB.jl), [SQLite.jl](https://github.com/JuliaDatabases/SQLite.jl), [MySQL.jl](https://github.com/JuliaDatabases/MySQL.jl), [JDBC.jl](https://github.com/JuliaDatabases/JDBC.jl), [ODBC.jl](https://github.com/JuliaDatabases/ODBC.jl), [XLSX.jl](https://github.com/felipenoris/XLSX.jl), etc.) @@ -48,6 +66,7 @@ using DataAPI, Tables, SentinelArrays, PooledArrays, + JSON3, CodecLz4, CodecZstd, TimeZones, @@ -55,13 +74,15 @@ using DataAPI, ConcurrentUtilities, StringViews -export ArrowTypes +export ArrowTypes, Flight using Base: @propagate_inbounds import Base: == const FILE_FORMAT_MAGIC_BYTES = b"ARROW1" const CONTINUATION_INDICATOR_BYTES = 0xffffffff +const TENSOR_UNSUPPORTED = "Tensor messages are not supported yet" +const SPARSE_TENSOR_UNSUPPORTED = "SparseTensor messages are not supported yet" # vendored flatbuffers code for now include("FlatBuffers/FlatBuffers.jl") @@ -73,12 +94,16 @@ const Meta = Flatbuf using ArrowTypes include("utils.jl") +include("logicaltypes.jl") include("arraytypes/arraytypes.jl") include("eltypes.jl") +include("logicaltypes_builtin.jl") include("table.jl") +include("metadata/overlay.jl") include("write.jl") include("append.jl") include("show.jl") +include("flight/Flight.jl") const ZSTD_COMPRESSOR = Lockable{ZstdCompressor}[] const ZSTD_DECOMPRESSOR = Lockable{ZstdDecompressor}[] diff --git a/src/ArrowTypes/Project.toml b/src/ArrowTypes/Project.toml index 0166f60..50fd379 100644 --- a/src/ArrowTypes/Project.toml +++ b/src/ArrowTypes/Project.toml @@ -25,4 +25,4 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] -julia = "1.0" +julia = "1.12" diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl index fe2223f..825c6a5 100644 --- a/src/ArrowTypes/src/ArrowTypes.jl +++ b/src/ArrowTypes/src/ArrowTypes.jl @@ -213,6 +213,111 @@ arrowname(::Type{Char}) = CHAR JuliaType(::Val{CHAR}) = Char fromarrow(::Type{Char}, x::UInt32) = Char(x) +ArrowType(::Type{T}) where {T<:Enum} = Base.Enums.basetype(T) +toarrow(x::T) where {T<:Enum} = Base.Enums.basetype(T)(x) +const ENUM = Symbol("JuliaLang.Enum") +arrowname(::Type{T}) where {T<:Enum} = ENUM + +function _qualifiedtypepath(::Type{T}) where {T} + module_path = join(string.(Base.fullname(parentmodule(T))), ".") + return string(module_path, ".", nameof(T)) +end + +function _enum_labels(::Type{T}) where {T<:Enum} + B = Base.Enums.basetype(T) + return join((string(instance, ":", B(instance)) for instance in instances(T)), ",") +end + +function _parseenumlabels(labels::AbstractString, ::Type{B}) where {B<:Integer} + pairs = Pair{String,B}[] + isempty(labels) && return pairs + for entry in split(labels, ',') + isempty(entry) && return nothing + delimiter = findfirst(==(':'), entry) + delimiter === nothing && return nothing + label = entry[1:prevind(entry, delimiter)] + value = entry[nextind(entry, delimiter):end] + isempty(label) && return nothing + parsed = tryparse(B, value) + parsed === nothing && return nothing + push!(pairs, label => parsed) + end + return pairs +end + +function _enumlabelsmatch(::Type{T}, labels::AbstractString) where {T<:Enum} + B = Base.Enums.basetype(T) + parsed = _parseenumlabels(labels, B) + parsed === nothing && return false + expected = [string(instance) => B(instance) for instance in instances(T)] + length(parsed) == length(expected) || return false + parsed_dict = Dict(parsed) + length(parsed_dict) == length(parsed) || return false + return parsed_dict == Dict(expected) +end + +function arrowmetadata(::Type{T}) where {T<:Enum} + return string("type=", _qualifiedtypepath(T), ";labels=", _enum_labels(T)) +end + +function _parsemetadata(metadata::AbstractString) + parsed = Dict{String,String}() + isempty(metadata) && return parsed + for entry in split(metadata, ';') + isempty(entry) && continue + delimiter = findfirst(==('='), entry) + delimiter === nothing && continue + key = entry[1:prevind(entry, delimiter)] + value = entry[nextind(entry, delimiter):end] + parsed[key] = value + end + return parsed +end + +function _rootmodule(name::Symbol) + name === :Main && return Main + if isdefined(Main, name) + candidate = getfield(Main, name) + candidate isa Module && return candidate + end + try + return Base.root_module(Main, name) + catch + return nothing + end +end + +function _resolvequalifiedtype(path::AbstractString) + parts = split(path, '.') + length(parts) < 2 && return nothing + current = _rootmodule(Symbol(first(parts))) + current isa Module || return nothing + for part in parts[2:(end - 1)] + symbol = Symbol(part) + isdefined(current, symbol) || return nothing + current = getfield(current, symbol) + current isa Module || return nothing + end + type_symbol = Symbol(last(parts)) + isdefined(current, type_symbol) || return nothing + return getfield(current, type_symbol) +end + +function JuliaType(::Val{ENUM}, S, metadata::String) + parsed = _parsemetadata(metadata) + haskey(parsed, "type") || return nothing + haskey(parsed, "labels") || return nothing + T = _resolvequalifiedtype(parsed["type"]) + T isa DataType || return nothing + T <: Enum || return nothing + storage_type = Base.nonmissingtype(S) + Base.Enums.basetype(T) === storage_type || return nothing + _enumlabelsmatch(T, parsed["labels"]) || return nothing + return T +end + +fromarrow(::Type{T}, x::Integer) where {T<:Enum} = T(x) + "BoolKind data is stored with values packed down to individual bits; so instead of a traditional Bool being 1 byte/8 bits, 8 Bool values would be packed into a single byte" struct BoolKind <: ArrowKind end ArrowKind(::Type{Bool}) = BoolKind() @@ -264,9 +369,11 @@ ArrowKind(::Type{NTuple{N,T}}) where {N,T} = FixedSizeListKind{N,T}() ArrowKind(::Type{UUID}) = FixedSizeListKind{16,UInt8}() ArrowType(::Type{UUID}) = NTuple{16,UInt8} toarrow(x::UUID) = _cast(NTuple{16,UInt8}, x.value) -const UUIDSYMBOL = Symbol("JuliaLang.UUID") +const UUIDSYMBOL = Symbol("arrow.uuid") +const LEGACY_UUIDSYMBOL = Symbol("JuliaLang.UUID") arrowname(::Type{UUID}) = UUIDSYMBOL JuliaType(::Val{UUIDSYMBOL}) = UUID +JuliaType(::Val{LEGACY_UUIDSYMBOL}) = UUID fromarrow(::Type{UUID}, x::NTuple{16,UInt8}) = UUID(_cast(UInt128, x)) ArrowKind(::Type{IPv4}) = PrimitiveKind() @@ -324,6 +431,14 @@ arrowname(::Type{Tuple{}}) = TUPLE JuliaType(::Val{TUPLE}, ::Type{NamedTuple{names,types}}) where {names,types<:Tuple} = types fromarrow(::Type{T}, x::NamedTuple) where {T<:Tuple} = Tuple(x) +# Complex +const COMPLEX = Symbol("JuliaLang.Complex") +arrowname(::Type{<:Complex}) = COMPLEX +JuliaType(::Val{COMPLEX}, ::Type{NamedTuple{names,Tuple{T,T}}}) where {names,T<:Real} = + Complex{T} +fromarrowstruct(::Type{T}, ::Val{(:re, :im)}, re, im) where {T<:Complex} = T(re, im) +fromarrowstruct(::Type{T}, ::Val{(:im, :re)}, im, re) where {T<:Complex} = T(re, im) + # VersionNumber const VERSION_NUMBER = Symbol("JuliaLang.VersionNumber") ArrowKind(::Type{VersionNumber}) = StructKind() @@ -359,6 +474,7 @@ function default end default(T) = zero(T) default(::Type{Symbol}) = Symbol() default(::Type{Char}) = '\0' +default(::Type{T}) where {T<:Enum} = first(instances(T)) default(::Type{<:AbstractString}) = "" default(::Type{Any}) = nothing default(::Type{Missing}) = missing @@ -388,13 +504,67 @@ default(::Type{NamedTuple{names,types}}) where {names,types} = NamedTuple{names}(Tuple(default(fieldtype(types, i)) for i = 1:length(names))) function promoteunion(T, S) + T === S && return T new = promote_type(T, S) return isabstracttype(new) ? Union{T,S} : new end +function _toarroweltype(x) + state = iterate(x) + state === nothing && return Missing + y, st = state + srcT = Union{} + stable = false + T = Missing + if y !== missing + srcT = typeof(y) + mapped = ArrowType(srcT) + stable = isconcretetype(mapped) + T = stable ? mapped : typeof(toarrow(y)) + end + while true + state = iterate(x, st) + state === nothing && return T + y, st = state + if y === missing + S = Missing + elseif srcT === Union{} + srcT = typeof(y) + mapped = ArrowType(srcT) + stable = isconcretetype(mapped) + S = stable ? mapped : typeof(toarrow(y)) + elseif stable && typeof(y) === srcT + continue + else + S = typeof(toarrow(y)) + if stable && typeof(y) !== srcT + stable = false + end + end + S === T && continue + T = promoteunion(T, S) + end +end + +@inline _hasoffsetaxes(data) = Base.has_offset_axes(data) +@inline _offsetshift(data) = _hasoffsetaxes(data) ? firstindex(data) - 1 : 0 +@inline _hasonebasedaxes(data) = !_hasoffsetaxes(data) + # lazily call toarrow(x) on getindex for each x in data struct ToArrow{T,A} <: AbstractVector{T} data::A + offset::Int + needsconvert::Bool +end +@inline _sourcedata(x::ToArrow) = getfield(x, :data) +@inline _sourceoffset(x::ToArrow) = getfield(x, :offset) +@inline _needsconvert(x::ToArrow) = getfield(x, :needsconvert) +@inline _sourcevalue(x::ToArrow, i::Integer) = + @inbounds getindex(_sourcedata(x), i + _sourceoffset(x)) + +function ToArrow{T,A}(data::A) where {T,A} + needsconvert = !(eltype(A) === T && concrete_or_concreteunion(T)) + return ToArrow{T,A}(data, _offsetshift(data), needsconvert) end concrete_or_concreteunion(T) = @@ -404,15 +574,14 @@ concrete_or_concreteunion(T) = function ToArrow(x::A) where {A} S = eltype(A) T = ArrowType(S) - fi = firstindex(x) - if S === T && concrete_or_concreteunion(S) && fi == 1 + if S === T && concrete_or_concreteunion(S) && _hasonebasedaxes(x) return x elseif !concrete_or_concreteunion(T) # arrow needs concrete types, so try to find a concrete common type, preferring unions if isempty(x) return Missing[] end - T = mapreduce(typeof ∘ toarrow, promoteunion, x) + T = _toarroweltype(x) if T === Missing && concrete_or_concreteunion(S) T = promoteunion(T, typeof(toarrow(default(S)))) end @@ -440,7 +609,29 @@ function _convert(::Type{T}, x) where {T} return convert(T, x) end end -Base.getindex(x::ToArrow{T}, i::Int) where {T} = - _convert(T, toarrow(getindex(x.data, i + firstindex(x.data) - 1))) + +@inline function _toarrowvalue(x::ToArrow{T}, value) where {T} + _needsconvert(x) || return value + return _convert(T, toarrow(value)) +end + +Base.@propagate_inbounds function Base.getindex(x::ToArrow{T}, i::Int) where {T} + value = _sourcevalue(x, i) + return _toarrowvalue(x, value) +end + +function Base.iterate(x::ToArrow) + state = iterate(x.data) + state === nothing && return nothing + value, st = state + return _toarrowvalue(x, value), st +end + +function Base.iterate(x::ToArrow, st) + state = iterate(x.data, st) + state === nothing && return nothing + value, st = state + return _toarrowvalue(x, value), st +end end # module ArrowTypes diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl index 22d8dd0..3363b4b 100644 --- a/src/ArrowTypes/test/tests.jl +++ b/src/ArrowTypes/test/tests.jl @@ -22,6 +22,22 @@ struct Person name::String end +module EnumTestModule +@enum RankingStrategy lexical=1 semantic=2 hybrid=3 +end + +module WideEnumTestModule +@enum WideRanking::UInt64 small=1 colossal=0xffffffffffffffff +end + +const RankingStrategy = EnumTestModule.RankingStrategy +const lexical = EnumTestModule.lexical +const semantic = EnumTestModule.semantic +const hybrid = EnumTestModule.hybrid +const WideRanking = WideEnumTestModule.WideRanking +const small = WideEnumTestModule.small +const colossal = WideEnumTestModule.colossal + @testset "ArrowTypes" begin @test ArrowTypes.ArrowKind(MyInt) == ArrowTypes.PrimitiveKind() @test ArrowTypes.ArrowKind(Person) == ArrowTypes.StructKind() @@ -67,6 +83,44 @@ end @test ArrowTypes.JuliaType(Val(ArrowTypes.CHAR)) == Char @test ArrowTypes.fromarrow(Char, UInt32('1')) == '1' + enum_metadata = ArrowTypes.arrowmetadata(RankingStrategy) + @test ArrowTypes.ArrowKind(RankingStrategy) == ArrowTypes.PrimitiveKind() + @test ArrowTypes.ArrowType(RankingStrategy) == Int32 + @test ArrowTypes.toarrow(hybrid) == Int32(3) + @test ArrowTypes.arrowname(RankingStrategy) == ArrowTypes.ENUM + @test occursin("type=Main.EnumTestModule.RankingStrategy", enum_metadata) + @test occursin("labels=lexical:1,semantic:2,hybrid:3", enum_metadata) + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, enum_metadata) == + RankingStrategy + reordered_enum_metadata = "type=Main.EnumTestModule.RankingStrategy;labels=semantic:2,hybrid:3,lexical:1" + mismatched_enum_metadata = "type=Main.EnumTestModule.RankingStrategy;labels=lexical:1,semantic:2,hybrid:4" + malformed_enum_metadata = "type=Main.EnumTestModule.RankingStrategy;labels=lexical:1,semantic:nope" + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, reordered_enum_metadata) == + RankingStrategy + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, mismatched_enum_metadata) === + nothing + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, malformed_enum_metadata) === + nothing + @test ArrowTypes.JuliaType( + Val(ArrowTypes.ENUM), + Int32, + "type=Main.EnumTestModule.RankingStrategy", + ) === nothing + @test ArrowTypes.fromarrow(RankingStrategy, Int32(2)) == semantic + @test ArrowTypes.default(RankingStrategy) == lexical + + wide_enum_metadata = ArrowTypes.arrowmetadata(WideRanking) + @test ArrowTypes.ArrowKind(WideRanking) == ArrowTypes.PrimitiveKind() + @test ArrowTypes.ArrowType(WideRanking) == UInt64 + @test ArrowTypes.toarrow(colossal) == typemax(UInt64) + @test ArrowTypes.arrowname(WideRanking) == ArrowTypes.ENUM + @test occursin("type=Main.WideEnumTestModule.WideRanking", wide_enum_metadata) + @test occursin("labels=small:1,colossal:18446744073709551615", wide_enum_metadata) + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), UInt64, wide_enum_metadata) == + WideRanking + @test ArrowTypes.fromarrow(WideRanking, typemax(UInt64)) == colossal + @test ArrowTypes.default(WideRanking) == small + @test ArrowTypes.ArrowKind(Bool) == ArrowTypes.BoolKind() @test ArrowTypes.ListKind() == ArrowTypes.ListKind{false}() @@ -106,6 +160,7 @@ end @test ArrowTypes.toarrow(u) == ubytes @test ArrowTypes.arrowname(UUID) == ArrowTypes.UUIDSYMBOL @test ArrowTypes.JuliaType(Val(ArrowTypes.UUIDSYMBOL)) == UUID + @test ArrowTypes.JuliaType(Val(ArrowTypes.LEGACY_UUIDSYMBOL)) == UUID @test ArrowTypes.fromarrow(UUID, ubytes) == u ip4 = IPv4(rand(UInt32)) @@ -144,6 +199,17 @@ end @test ArrowTypes.default(Tuple{Vararg{Int}}) == () @test ArrowTypes.default(Tuple{String,Vararg{Int}}) == ("",) + z = 1.0 + 2.0im + @test ArrowTypes.ArrowKind(typeof(z)) == ArrowTypes.StructKind() + @test ArrowTypes.arrowname(typeof(z)) == ArrowTypes.COMPLEX + @test ArrowTypes.arrowname(Union{Missing,typeof(z)}) == ArrowTypes.COMPLEX + @test ArrowTypes.JuliaType( + Val(ArrowTypes.COMPLEX), + NamedTuple{(:re, :im),Tuple{Float64,Float64}}, + ) == ComplexF64 + @test ArrowTypes.fromarrowstruct(ComplexF64, Val((:re, :im)), 1.0, 2.0) == z + @test ArrowTypes.fromarrowstruct(ComplexF64, Val((:im, :re)), 2.0, 1.0) == z + v = v"1" v_nt = (major=1, minor=0, patch=0, prerelease=(), build=()) @test ArrowTypes.ArrowKind(VersionNumber) == ArrowTypes.StructKind() @@ -167,39 +233,97 @@ end @test ArrowTypes.promoteunion(Int, Float64) == Float64 @test ArrowTypes.promoteunion(Int, String) == Union{Int,String} + @test ArrowTypes.promoteunion(Int, Int) == Int @test ArrowTypes.concrete_or_concreteunion(Int) @test !ArrowTypes.concrete_or_concreteunion(Union{Real,String}) @test !ArrowTypes.concrete_or_concreteunion(Any) @testset "ToArrow" begin + @test !ArrowTypes._hasoffsetaxes([1, 2, 3]) + @test ArrowTypes._offsetshift([1, 2, 3]) == 0 + x = ArrowTypes.ToArrow([1, 2, 3]) @test x isa Vector{Int} @test x == [1, 2, 3] + baseview = @view [1, 2, 3][1:3] + x = ArrowTypes.ToArrow(baseview) + @test x === baseview + x = ArrowTypes.ToArrow([:hey, :ho]) @test x isa ArrowTypes.ToArrow{String,Vector{Symbol}} @test eltype(x) == String + @test ArrowTypes._needsconvert(x) + @test x[1] == "hey" + @test collect(x) == ["hey", "ho"] @test x == ["hey", "ho"] x = ArrowTypes.ToArrow(Any[1, 3.14]) @test x isa ArrowTypes.ToArrow{Float64,Vector{Any}} @test eltype(x) == Float64 + @test collect(x) == [1.0, 3.14] @test x == [1.0, 3.14] + x = ArrowTypes.ToArrow(Any[UUID(UInt128(1)), UUID(UInt128(2))]) + @test x isa ArrowTypes.ToArrow{NTuple{16,UInt8},Vector{Any}} + @test eltype(x) == NTuple{16,UInt8} + @test collect(x) == + [ArrowTypes.toarrow(UUID(UInt128(1))), ArrowTypes.toarrow(UUID(UInt128(2)))] + + x = ArrowTypes.ToArrow(Any[missing, UUID(UInt128(1))]) + @test x isa ArrowTypes.ToArrow{Union{Missing,NTuple{16,UInt8}},Vector{Any}} + @test eltype(x) == Union{Missing,NTuple{16,UInt8}} + @test isequal( + collect(x), + Union{Missing,NTuple{16,UInt8}}[missing, ArrowTypes.toarrow(UUID(UInt128(1)))], + ) + x = ArrowTypes.ToArrow(Any[1, 3.14, "hey"]) @test x isa ArrowTypes.ToArrow{Union{Float64,String},Vector{Any}} @test eltype(x) == Union{Float64,String} + @test collect(x) == Union{Float64,String}[1.0, 3.14, "hey"] @test x == [1.0, 3.14, "hey"] + x = ArrowTypes.ToArrow(Any[UUID(UInt128(1)), "tail"]) + @test x isa ArrowTypes.ToArrow{Union{NTuple{16,UInt8},String},Vector{Any}} + @test eltype(x) == Union{NTuple{16,UInt8},String} + @test collect(x) == + Union{NTuple{16,UInt8},String}[ArrowTypes.toarrow(UUID(UInt128(1))), "tail"] + x = ArrowTypes.ToArrow(OffsetArray([1, 2, 3], -3:-1)) @test x isa ArrowTypes.ToArrow{Int,OffsetVector{Int,Vector{Int}}} + @test ArrowTypes._hasoffsetaxes(getfield(x, :data)) + @test getfield(x, :offset) == ArrowTypes._offsetshift(getfield(x, :data)) + @test ArrowTypes._sourcedata(x) === getfield(x, :data) + @test ArrowTypes._sourceoffset(x) == getfield(x, :offset) + @test !ArrowTypes._needsconvert(x) + @test ArrowTypes._sourcevalue(x, 1) == 1 @test eltype(x) == Int + @test x[1] == 1 + @test x[3] == 3 + @test collect(x) == [1, 2, 3] @test x == [1, 2, 3] + x = ArrowTypes.ToArrow(OffsetArray(Union{Missing,Int}[1, missing], -3:-2)) + @test x isa ArrowTypes.ToArrow{ + Union{Missing,Int}, + OffsetVector{Union{Missing,Int},Vector{Union{Missing,Int}}}, + } + @test !ArrowTypes._needsconvert(x) + @test x[1] == 1 + @test x[2] === missing + @test isequal(collect(x), Union{Missing,Int}[1, missing]) + x = ArrowTypes.ToArrow(OffsetArray(Any[1, 3.14], -3:-2)) @test x isa ArrowTypes.ToArrow{Float64,OffsetVector{Any,Vector{Any}}} + @test getfield(x, :offset) == ArrowTypes._offsetshift(getfield(x, :data)) + @test ArrowTypes._sourcevalue(x, 2) == 3.14 @test eltype(x) == Float64 + @test ArrowTypes._needsconvert(x) + @test x[1] == 1 + @test x[2] == 3.14 + @test collect(x) == [1.0, 3.14] @test x == [1, 3.14] @testset "respect non-missing concrete type" begin @@ -219,6 +343,15 @@ end T = Union{DateTimeTZ,Missing} @test !ArrowTypes.concrete_or_concreteunion(ArrowTypes.ArrowType(T)) @test eltype(ArrowTypes.ToArrow(T[missing])) == Union{Timestamp{:UTC},Missing} + @test eltype( + ArrowTypes.ToArrow(DateTimeTZ[DateTimeTZ(1, "UTC"), DateTimeTZ(2, "UTC")]), + ) == Timestamp{:UTC} + @test eltype( + ArrowTypes.ToArrow(DateTimeTZ[DateTimeTZ(1, "UTC"), DateTimeTZ(2, "PST")]), + ) == Timestamp + @test eltype( + ArrowTypes.ToArrow(Any[DateTimeTZ(1, "UTC"), DateTimeTZ(2, "UTC")]), + ) == Timestamp{:UTC} # Works since `ArrowTypes.default(Any) === nothing` and # `ArrowTypes.toarrow(nothing) === missing`. Defining `toarrow(::Nothing) = nothing` diff --git a/src/arraytypes/arraytypes.jl b/src/arraytypes/arraytypes.jl index 58bab08..281dab1 100644 --- a/src/arraytypes/arraytypes.jl +++ b/src/arraytypes/arraytypes.jl @@ -99,17 +99,11 @@ function arrowvector( dictencode=dictencode, kw..., ) - elseif !(x isa DictEncode) + elseif !(x isa DictEncode) && !_keeprawmapvector(T, x) x = ToArrow(x) end S = maybemissing(eltype(x)) - if ArrowTypes.hasarrowname(T) - meta = _arrowtypemeta( - _normalizemeta(meta), - String(ArrowTypes.arrowname(T)), - String(ArrowTypes.arrowmetadata(T)), - ) - end + meta = _extensionmetadatafor(T, _normalizemeta(meta)) return arrowvector( S, x, @@ -133,15 +127,62 @@ _normalizecolmeta(colmeta) = toidict( Symbol(k) => toidict(String(v1) => String(v2) for (v1, v2) in v) for (k, v) in colmeta ) -function _arrowtypemeta(::Nothing, n, m) - return toidict(("ARROW:extension:name" => n, "ARROW:extension:metadata" => m)) +@inline function _materializeconverted(x::ArrowTypes.ToArrow) + data = ArrowTypes._sourcedata(x) + if ArrowTypes._needsconvert(x) && !ArrowTypes.concrete_or_concreteunion(eltype(data)) + return _materializeconverted(eltype(x), x) + end + return x end -function _arrowtypemeta(meta, n, m) - dict = Dict(meta) - dict["ARROW:extension:name"] = n - dict["ARROW:extension:metadata"] = m - return toidict(dict) +function _materializeconverted(::Type{T}, x::ArrowTypes.ToArrow{T,A}) where {T,A} + len = length(x) + data = Vector{T}(undef, len) + source = ArrowTypes._sourcedata(x) + i = 1 + for value in source + @inbounds data[i] = + value isa T ? value : ArrowTypes._convert(T, ArrowTypes.toarrow(value)) + i += 1 + end + return data +end + +@inline function _materializefixedbytes16(value) + if value isa ArrowTypes.UUID + return ArrowTypes._cast(NTuple{16,UInt8}, value.value) + elseif value isa NTuple{16,UInt8} + return value + else + return ArrowTypes._convert(NTuple{16,UInt8}, ArrowTypes.toarrow(value)) + end +end + +function _materializeconverted( + ::Type{NTuple{16,UInt8}}, + x::ArrowTypes.ToArrow{NTuple{16,UInt8},A}, +) where {A} + len = length(x) + data = Vector{NTuple{16,UInt8}}(undef, len) + source = ArrowTypes._sourcedata(x) + i = 1 + for value in source + @inbounds data[i] = _materializefixedbytes16(value) + i += 1 + end + return data +end + +@inline _toarrowvaliditysource(x::ArrowTypes.ToArrow) = + ArrowTypes._needsconvert(x) ? x : ArrowTypes._sourcedata(x) + +@inline _toarrowvalidity(x::ArrowTypes.ToArrow, data) = + data === x ? ValidityBitmap(x) : ValidityBitmap(data) + +@inline function _keeprawmapvector(::Type{T}, x) where {T} + return Base.has_offset_axes(x) && + ArrowTypes.concrete_or_concreteunion(T) && + ArrowKind(T) isa ArrowTypes.MapKind end # now we check for ArrowType converions and dispatch on ArrowKind @@ -201,15 +242,10 @@ end Base.size(p::ValidityBitmap) = (p.ℓ,) nullcount(x::ValidityBitmap) = x.nc -function ValidityBitmap(x) - T = eltype(x) - if !(T >: Missing) - return ValidityBitmap(UInt8[], 1, length(x), 0) - end +function _validitybitmap(x, len) len = length(x) blen = cld(len, 8) bytes = Vector{UInt8}(undef, blen) - st = iterate(x) nc = 0 b = 0xff j = k = 1 @@ -232,6 +268,23 @@ function ValidityBitmap(x) return ValidityBitmap(nc == 0 ? UInt8[] : bytes, 1, nc == 0 ? 0 : len, nc) end +function ValidityBitmap(x) + T = eltype(x) + if !(T >: Missing) + return ValidityBitmap(UInt8[], 1, length(x), 0) + end + return _validitybitmap(x, length(x)) +end + +function ValidityBitmap(x::ArrowTypes.ToArrow) + T = eltype(x) + if !(T >: Missing) + return ValidityBitmap(UInt8[], 1, length(x), 0) + end + source = _toarrowvaliditysource(x) + return _validitybitmap(source, length(x)) +end + @propagate_inbounds function Base.getindex(p::ValidityBitmap, i::Integer) # no boundscheck because parent array should do it # if a validity bitmap is empty, it either means: @@ -272,3 +325,4 @@ include("struct.jl") include("unions.jl") include("dictencoding.jl") include("views.jl") +include("runendencoded.jl") diff --git a/src/arraytypes/bool.jl b/src/arraytypes/bool.jl index 29c1505..8a33668 100644 --- a/src/arraytypes/bool.jl +++ b/src/arraytypes/bool.jl @@ -52,9 +52,7 @@ end arrowvector(::BoolKind, x::BoolVector, i, nl, fi, de, ded, meta; kw...) = x -function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...) - validity = ValidityBitmap(x) - len = length(x) +function _packboolbytes(x, len) blen = cld(len, 8) bytes = Vector{UInt8}(undef, blen) b = 0xff @@ -74,9 +72,25 @@ function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...) if j > 1 bytes[k] = b end + return bytes +end + +function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...) + validity = ValidityBitmap(x) + len = length(x) + bytes = _packboolbytes(x, len) return BoolVector{eltype(x)}(bytes, 1, validity, len, meta) end +function arrowvector(::BoolKind, x::ArrowTypes.ToArrow, i, nl, fi, de, ded, meta; kw...) + data = _materializeconverted(x) + validity = _toarrowvalidity(x, data) + len = length(data) + source = data === x ? _toarrowvaliditysource(x) : data + bytes = _packboolbytes(source, len) + return BoolVector{eltype(data)}(bytes, 1, validity, len, meta) +end + function compress(Z::Meta.CompressionType.T, comp, p::P) where {P<:BoolVector} len = length(p) nc = nullcount(p) diff --git a/src/arraytypes/dictencoding.jl b/src/arraytypes/dictencoding.jl index 3e3576c..c8582e8 100644 --- a/src/arraytypes/dictencoding.jl +++ b/src/arraytypes/dictencoding.jl @@ -119,6 +119,8 @@ signedtype(::Type{UInt32}) = Int32 signedtype(::Type{UInt64}) = Int64 signedtype(::Type{T}) where {T<:Signed} = T +@inline _dictrefshift(pool) = firstindex(pool) + indtype(d::DictEncoded{T,S,A}) where {T,S,A} = S indtype(c::Compressed{Z,A}) where {Z,A<:DictEncoded} = indtype(c.data) @@ -232,7 +234,7 @@ function arrowvector( inds = copyto!(similar(Vector{signedtype(length(pool))}, length(refa)), refa) end # adjust to "offset" instead of index - inds .-= firstindex(refa) + inds .-= _dictrefshift(pool) data = arrowvector( pool, i, diff --git a/src/arraytypes/fixedsizelist.jl b/src/arraytypes/fixedsizelist.jl index 2558dd5..4e8f74c 100644 --- a/src/arraytypes/fixedsizelist.jl +++ b/src/arraytypes/fixedsizelist.jl @@ -81,6 +81,8 @@ struct ToFixedSizeList{T,N,A} <: AbstractVector{T} end origtype(::ToFixedSizeList{T,N,A}) where {T,N,A} = eltype(A) +@inline _fixedsizedata(A::ToFixedSizeList) = getfield(A, :data) +@inline _fixedsizevalue(A::ToFixedSizeList, i::Integer) = @inbounds _fixedsizedata(A)[i] function ToFixedSizeList(input) NT = ArrowTypes.ArrowKind(Base.nonmissingtype(eltype(input))) # typically NTuple{N, T} @@ -90,7 +92,7 @@ function ToFixedSizeList(input) end Base.IndexStyle(::Type{<:ToFixedSizeList}) = Base.IndexLinear() -Base.size(x::ToFixedSizeList{T,N}) where {T,N} = (N * length(x.data),) +Base.size(x::ToFixedSizeList{T,N}) where {T,N} = (N * length(_fixedsizedata(x)),) Base.@propagate_inbounds function Base.getindex( A::ToFixedSizeList{T,N}, @@ -98,7 +100,7 @@ Base.@propagate_inbounds function Base.getindex( ) where {T,N} @boundscheck checkbounds(A, i) a, b = fldmod1(i, N) - @inbounds x = A.data[a] + x = _fixedsizevalue(A, a) return @inbounds x === missing ? ArrowTypes.default(T) : x[b] end @@ -108,7 +110,7 @@ end (i, chunk, chunk_i, len)=(1, 1, 1, length(A)), ) where {T,N} i > len && return nothing - @inbounds y = A.data[chunk] + y = _fixedsizevalue(A, chunk) @inbounds x = y === missing ? ArrowTypes.default(T) : y[chunk_i] if chunk_i == N chunk += 1 @@ -119,8 +121,60 @@ end return x, (i + 1, chunk, chunk_i, len) end +@inline function _writefixedsizechunk(io::IO, chunk::NTuple{N,UInt8}) where {N} + ref = Ref(chunk) + GC.@preserve ref begin + return Base.unsafe_write(io, Base.unsafe_convert(Ptr{UInt8}, ref), N) + end +end + +@inline function _writefixedsizecontiguous(io::IO, data::Vector{NTuple{N,UInt8}}) where {N} + GC.@preserve data begin + return Base.unsafe_write(io, Ptr{UInt8}(pointer(data)), N * length(data)) + end +end + +function writearray(io::IO, ::Type{UInt8}, col::ToFixedSizeList{UInt8,N}) where {N} + n = 0 + defaultchunk = ntuple(_ -> ArrowTypes.default(UInt8), Val(N)) + data = _fixedsizedata(col) + data isa Vector{NTuple{N,UInt8}} && return _writefixedsizecontiguous(io, data) + for chunk in data + n += _writefixedsizechunk(io, chunk === missing ? defaultchunk : chunk) + end + return n +end + arrowvector(::FixedSizeListKind, x::FixedSizeList, i, nl, fi, de, ded, meta; kw...) = x +function arrowvector( + kind::FixedSizeListKind{N,T}, + x::ArrowTypes.ToArrow, + i, + nl, + fi, + de, + ded, + meta; + kw..., +) where {N,T} + data = _materializeconverted(x) + if data !== x + return arrowvector(kind, data, i, nl, fi, de, ded, meta; kw...) + end + len = length(x) + validity = ValidityBitmap(x) + flat = ToFixedSizeList(x) + if eltype(flat) == UInt8 + child = flat + S = origtype(flat) + else + child = arrowvector(flat, i, nl + 1, fi, de, ded, nothing; kw...) + S = withmissing(eltype(x), NTuple{N,eltype(child)}) + end + return FixedSizeList{S,typeof(child)}(UInt8[], validity, child, len, meta) +end + function arrowvector( ::FixedSizeListKind{N,T}, x, diff --git a/src/arraytypes/list.jl b/src/arraytypes/list.jl index 41ac66f..5d0cf6d 100644 --- a/src/arraytypes/list.jl +++ b/src/arraytypes/list.jl @@ -86,13 +86,17 @@ _codeunits(x::Base.CodeUnits) = x # an AbstractVector version of Iterators.flatten # code based on SentinelArrays.ChainedVector -struct ToList{T,stringtype,A,I} <: AbstractVector{T} - data::Vector{A} # A is AbstractVector or AbstractString +struct ToList{T,stringtype,A<:AbstractVector,I} <: AbstractVector{T} + data::A # A is the outer AbstractVector of AbstractVector or AbstractString inds::Vector{I} + offset::Int end -origtype(::ToList{T,S,A,I}) where {T,S,A,I} = A +origtype(::ToList{T,S,A,I}) where {T,S,A<:AbstractVector,I} = eltype(A) liststringtype(::Type{ToList{T,S,A,I}}) where {T,S,A,I} = S +materializeouter(::Type) = false +materializeouter(input) = materializeouter(typeof(input)) +materializeouterdata(input) = materializeouter(input) ? collect(input) : input function liststringtype(::List{T,O,A}) where {T,O,A} ST = Base.nonmissingtype(T) K = ArrowTypes.ArrowKind(ST) @@ -100,42 +104,80 @@ function liststringtype(::List{T,O,A}) where {T,O,A} end liststringtype(T) = false -function ToList(input; largelists::Bool=false) +@inline function _tolisttraits(input) AT = eltype(input) ST = Base.nonmissingtype(AT) K = ArrowTypes.ArrowKind(ST) stringtype = ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now T = stringtype ? UInt8 : eltype(ST) - len = stringtype ? _ncodeunits : length - data = AT[] + lenf = stringtype ? _ncodeunits : length + return T, stringtype, lenf +end + +@inline function _promotetolistinds(inds::Vector{Int32}, len::Int, filled::Int) + promoted = Vector{Int64}(undef, len + 1) + copyto!(promoted, 1, inds, 1, filled) + return promoted +end + +function _buildtolist(input, data, dataoffset::Int, len::Int; largelists::Bool=false) + T, stringtype, lenf = _tolisttraits(input) I = largelists ? Int64 : Int32 - inds = I[0] - sizehint!(data, length(input)) - sizehint!(inds, length(input)) + inds = Vector{I}(undef, len + 1) + inds[1] = zero(I) totalsize = I(0) - for x in input - if x === missing - push!(data, missing) - else - push!(data, x) - totalsize += len(x) - if I === Int32 && totalsize > 2147483647 + @inbounds for i = 1:len + x = data[i + dataoffset] + if x !== missing + totalsize += lenf(x) + if I === Int32 && totalsize > typemax(Int32) I = Int64 - inds = convert(Vector{Int64}, inds) + inds = _promotetolistinds(inds, len, i) end end - push!(inds, totalsize) + inds[i + 1] = totalsize end - return ToList{T,stringtype,AT,I}(data, inds) + return ToList{T,stringtype,typeof(data),I}(data, inds, dataoffset) +end + +function _tolistgeneric(input; largelists::Bool=false) + data = materializeouterdata(input) + return _buildtolist( + input, + data, + ArrowTypes._offsetshift(data), + length(data); + largelists=largelists, + ) +end + +function ToList(input; largelists::Bool=false) + return _tolistgeneric(input; largelists=largelists) +end + +function ToList(input::ArrowTypes.ToArrow; largelists::Bool=false) + ArrowTypes._needsconvert(input) && return _tolistgeneric(input; largelists=largelists) + data = ArrowTypes._sourcedata(input) + return _buildtolist( + input, + data, + ArrowTypes._sourceoffset(input), + length(input); + largelists=largelists, + ) end Base.IndexStyle(::Type{<:ToList}) = Base.IndexLinear() Base.size(x::ToList{T,S,A,I}) where {T,S,A,I} = (isempty(x.inds) ? zero(I) : x.inds[end],) +@inline _tolistdata(A::ToList) = getfield(A, :data) +@inline _tolistoffset(A::ToList) = getfield(A, :offset) +@inline _tolistchunk(A::ToList, i::Integer) = @inbounds _tolistdata(A)[i + _tolistoffset(A)] + function Base.pointer(A::ToList{UInt8}, i::Integer) chunk = searchsortedfirst(A.inds, i) chunk = chunk > length(A.inds) ? 1 : (chunk - 1) - return pointer(A.data[chunk]) + return pointer(_tolistchunk(A, chunk)) end @inline function index(A::ToList, i::Integer) @@ -149,7 +191,7 @@ Base.@propagate_inbounds function Base.getindex( ) where {T,stringtype} @boundscheck checkbounds(A, i) chunk, ix = index(A, i) - @inbounds x = A.data[chunk] + x = _tolistchunk(A, chunk) return @inbounds stringtype ? _codeunits(x)[ix] : x[ix] end @@ -160,7 +202,7 @@ Base.@propagate_inbounds function Base.setindex!( ) where {T,stringtype} @boundscheck checkbounds(A, i) chunk, ix = index(A, i) - @inbounds x = A.data[chunk] + x = _tolistchunk(A, chunk) if stringtype _codeunits(x)[ix] = v else @@ -180,7 +222,7 @@ end chunk += 1 chunk_len = A.inds[chunk] end - val = A.data[chunk - 1] + val = _tolistchunk(A, chunk - 1) x = stringtype ? _codeunits(val)[1] : val[1] # find next valid index i += 1 @@ -202,7 +244,7 @@ end (i, chunk, chunk_i, chunk_len, len), ) where {T,stringtype} i > len && return nothing - @inbounds val = A.data[chunk - 1] + val = _tolistchunk(A, chunk - 1) @inbounds x = stringtype ? _codeunits(val)[chunk_i] : val[chunk_i] i += 1 if i > chunk_len @@ -219,6 +261,100 @@ end return x, (i, chunk, chunk_i, chunk_len, len) end +@inline function _writeuint8chunk(io::IO, bytes) + GC.@preserve bytes begin + return Base.unsafe_write(io, pointer(bytes), length(bytes)) + end +end + +@inline function _writeutf8chunk(io::IO, chunk::AbstractString) + GC.@preserve chunk begin + return Base.unsafe_write(io, pointer(chunk), ncodeunits(chunk)) + end +end + +@inline function _sizehint_iobuffer!(io::IO, n::Integer) + io isa IOBuffer || return nothing + data = getfield(io, :data) + data isa Vector{UInt8} || return nothing + sizehint!(data, max(length(data), position(io) + n)) + return nothing +end + +function _writearray_tolist_bitstype(io::IO, ::Type{T}, col::ToList{T,false}) where {T} + n = 0 + off = _tolistoffset(col) + data = _tolistdata(col) + if off == 0 + for chunk in data + chunk === missing && continue + n += writearray(io, T, chunk) + end + else + len = length(data) + @inbounds for i = 1:len + chunk = data[i + off] + chunk === missing && continue + n += writearray(io, T, chunk) + end + end + return n +end + +function _writearray_tolist_uint8(io::IO, col::ToList{UInt8,stringtype}) where {stringtype} + n = 0 + _sizehint_iobuffer!(io, length(col)) + off = _tolistoffset(col) + data = _tolistdata(col) + if off == 0 + for chunk in data + chunk === missing && continue + bytes = stringtype ? _codeunits(chunk) : chunk + n += _writeuint8chunk(io, bytes) + end + else + len = length(data) + @inbounds for i = 1:len + chunk = data[i + off] + chunk === missing && continue + bytes = stringtype ? _codeunits(chunk) : chunk + n += _writeuint8chunk(io, bytes) + end + end + return n +end + +function _writearray_tolist_uint8( + io::IO, + col::ToList{UInt8,true,A}, +) where {A<:AbstractVector{<:AbstractString}} + n = 0 + _sizehint_iobuffer!(io, length(col)) + off = _tolistoffset(col) + data = _tolistdata(col) + if off == 0 + for chunk in data + chunk === missing && continue + n += _writeutf8chunk(io, chunk) + end + else + len = length(data) + @inbounds for i = 1:len + chunk = data[i + off] + chunk === missing && continue + n += _writeutf8chunk(io, chunk) + end + end + return n +end + +function writearray(io::IO, ::Type{T}, col::ToList{T,stringtype}) where {T,stringtype} + T === UInt8 && return _writearray_tolist_uint8(io, col) + isbitstype(T) || return _writearrayfallback(io, T, col) + stringtype && return _writearrayfallback(io, T, col) + return _writearray_tolist_bitstype(io, T, col) +end + arrowvector(::ListKind, x::List, i, nl, fi, de, ded, meta; kw...) = x function arrowvector(::ListKind, x, i, nl, fi, de, ded, meta; largelists::Bool=false, kw...) diff --git a/src/arraytypes/map.jl b/src/arraytypes/map.jl index 4216073..c7b8275 100644 --- a/src/arraytypes/map.jl +++ b/src/arraytypes/map.jl @@ -43,8 +43,108 @@ Base.size(l::Map) = (l.ℓ,) end end -keyvalues(KT, ::Missing) = missing -keyvalues(KT, x::AbstractDict) = [KT(k, v) for (k, v) in pairs(x)] +@inline function _promotemapoffsets(offsets::Vector{Int32}, len::Int, filled::Int) + promoted = Vector{Int64}(undef, len + 1) + copyto!(promoted, 1, offsets, 1, filled) + return promoted +end + +function _mapoffsetsandvaluesindexed(::Type{KT}, x; largelists::Bool=false) where {KT} + len = length(x) + O = largelists ? Int64 : Int32 + offsets = Vector{O}(undef, len + 1) + offsets[1] = zero(O) + total = 0 + off = firstindex(x) - 1 + @inbounds for i = 1:len + y = x[i + off] + if y !== missing + total += length(y) + if O === Int32 && total > typemax(Int32) + O = Int64 + offsets = _promotemapoffsets(offsets, len, i) + end + end + offsets[i + 1] = total + end + values = Vector{KT}(undef, total) + pos = 1 + @inbounds for i = 1:len + y = x[i + off] + y === missing && continue + for (k, v) in pairs(y) + values[pos] = KT(k, v) + pos += 1 + end + end + return offsets, values +end + +function mapoffsetsandvalues(::Type{KT}, x; largelists::Bool=false) where {KT} + Base.has_offset_axes(x) && + return _mapoffsetsandvaluesindexed(KT, x; largelists=largelists) + len = length(x) + O = largelists ? Int64 : Int32 + offsets = Vector{O}(undef, len + 1) + offsets[1] = zero(O) + total = 0 + i = 1 + for y in x + if y !== missing + total += length(y) + if O === Int32 && total > typemax(Int32) + O = Int64 + offsets = _promotemapoffsets(offsets, len, i) + end + end + @inbounds offsets[i + 1] = total + i += 1 + end + values = Vector{KT}(undef, total) + pos = 1 + for y in x + y === missing && continue + for (k, v) in pairs(y) + @inbounds values[pos] = KT(k, v) + pos += 1 + end + end + return offsets, values +end + +function mapoffsetsandvalues( + ::Type{KT}, + x::ArrowTypes.ToArrow; + largelists::Bool=false, +) where {KT} + len = length(x) + O = largelists ? Int64 : Int32 + offsets = Vector{O}(undef, len + 1) + offsets[1] = zero(O) + total = 0 + @inbounds for i = 1:len + y = x[i] + if y !== missing + total += length(y) + if O === Int32 && total > typemax(Int32) + O = Int64 + offsets = _promotemapoffsets(offsets, len, i) + end + end + offsets[i + 1] = total + end + values = Vector{KT}(undef, total) + pos = 1 + @inbounds for i = 1:len + y = x[i] + y === missing && continue + for (k, v) in pairs(y) + values[pos] = KT(k, v) + pos += 1 + end + end + return offsets, values +end keyvaluetypes(::Type{NamedTuple{(:key, :value),Tuple{K,V}}}) where {K,V} = (K, V) @@ -67,13 +167,12 @@ function arrowvector(::MapKind, x, i, nl, fi, de, ded, meta; largelists::Bool=fa ), ) KT = KeyValue{KDT,VDT} - VT = Vector{KT} - T = DT !== ET ? Union{Missing,VT} : VT - flat = ToList(T[keyvalues(KT, y) for y in x]; largelists=largelists) - offsets = Offsets(UInt8[], flat.inds) - data = arrowvector(flat, i, nl + 1, fi, de, ded, nothing; largelists=largelists, kw...) + offsetsdata, values = mapoffsetsandvalues(KT, x; largelists=largelists) + offsets = Offsets(UInt8[], offsetsdata) + data = + arrowvector(values, i, nl + 1, fi, de, ded, nothing; largelists=largelists, kw...) K, V = keyvaluetypes(eltype(data)) - return Map{withmissing(ET, Dict{K,V}),eltype(flat.inds),typeof(data)}( + return Map{withmissing(ET, Dict{K,V}),eltype(offsetsdata),typeof(data)}( validity, offsets, data, diff --git a/src/arraytypes/primitive.jl b/src/arraytypes/primitive.jl index 7d86bfe..fbd6483 100644 --- a/src/arraytypes/primitive.jl +++ b/src/arraytypes/primitive.jl @@ -70,6 +70,22 @@ function arrowvector(::PrimitiveKind, x, i, nl, fi, de, ded, meta; kw...) return Primitive(eltype(x), UInt8[], validity, x, length(x), meta) end +function arrowvector( + ::PrimitiveKind, + x::ArrowTypes.ToArrow, + i, + nl, + fi, + de, + ded, + meta; + kw..., +) + data = _materializeconverted(x) + validity = _toarrowvalidity(x, data) + return Primitive(eltype(data), UInt8[], validity, data, length(data), meta) +end + function compress(Z::Meta.CompressionType.T, comp, p::P) where {P<:Primitive} len = length(p) nc = nullcount(p) diff --git a/src/arraytypes/runendencoded.jl b/src/arraytypes/runendencoded.jl new file mode 100644 index 0000000..05946e3 --- /dev/null +++ b/src/arraytypes/runendencoded.jl @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + Arrow.RunEndEncoded + +A read-only `ArrowVector` for Arrow Run-End Encoded arrays. Logical indexing is +resolved by binary searching the physical `run_ends` child and then indexing the +corresponding `values` child. +""" +struct RunEndEncoded{T,R,A} <: ArrowVector{T} + run_ends::R + values::A + ℓ::Int + metadata::Union{Nothing,Base.ImmutableDict{String,String}} +end + +Base.size(r::RunEndEncoded) = (r.ℓ,) +Base.copy(r::RunEndEncoded) = collect(r) + +@inline _reephysicalindex(r::RunEndEncoded, i::Integer) = searchsortedfirst(r.run_ends, i) + +function _validaterunendencoded(run_ends, values, len) + nruns = length(run_ends) + nvals = length(values) + nruns == nvals || throw( + ArgumentError( + "invalid Run-End Encoded array: run_ends length $nruns does not match values length $nvals", + ), + ) + if len == 0 + nruns == 0 || throw( + ArgumentError( + "invalid Run-End Encoded array: zero logical length requires zero runs", + ), + ) + elseif nruns == 0 + throw( + ArgumentError( + "invalid Run-End Encoded array: non-zero logical length requires at least one run", + ), + ) + end + last_end = 0 + for (idx, run_end) in enumerate(run_ends) + current_end = Int(run_end) + current_end > last_end || throw( + ArgumentError( + "invalid Run-End Encoded array: run_ends must be strictly increasing positive integers (failed at run $idx)", + ), + ) + last_end = current_end + end + len == 0 || + last_end == len || + throw( + ArgumentError( + "invalid Run-End Encoded array: final run end $last_end does not match logical length $len", + ), + ) + return +end + +function RunEndEncoded(run_ends::R, values::A, len, meta) where {R,A} + _validaterunendencoded(run_ends, values, Int(len)) + T = eltype(values) + return RunEndEncoded{T,R,A}(run_ends, values, Int(len), meta) +end + +function _makerunendencoded(::Type{T}, run_ends::R, values::A, len, meta) where {T,R,A} + _validaterunendencoded(run_ends, values, Int(len)) + return RunEndEncoded{T,R,A}(run_ends, values, Int(len), meta) +end + +@propagate_inbounds function Base.getindex(r::RunEndEncoded{T}, i::Integer) where {T} + @boundscheck checkbounds(r, i) + physical = _reephysicalindex(r, i) + physical <= length(r.values) || throw( + ArgumentError( + "invalid Run-End Encoded array: no physical value found for logical index $i", + ), + ) + return @inbounds ArrowTypes.fromarrow(T, r.values[physical]) +end + +function toarrowvector( + x::RunEndEncoded, + i=1, + de=Dict{Int64,Any}(), + ded=DictEncoding[], + meta=getmetadata(x); + compression::Union{Nothing,Symbol,LZ4FrameCompressor,ZstdCompressor}=nothing, + kw..., +) + throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) +end diff --git a/src/arraytypes/struct.jl b/src/arraytypes/struct.jl index 23a8b64..b66633e 100644 --- a/src/arraytypes/struct.jl +++ b/src/arraytypes/struct.jl @@ -80,13 +80,72 @@ end ToStruct(x::A, j::Integer) where {A} = ToStruct{fieldtype(Base.nonmissingtype(eltype(A)), j),j,A}(x) +@inline _structsource(A::ToStruct) = getfield(A, :data) +@inline _structsourcevalue(A::ToStruct, i::Integer) = @inbounds _structsource(A)[i] + Base.IndexStyle(::Type{<:ToStruct}) = Base.IndexLinear() -Base.size(x::ToStruct) = (length(x.data),) +Base.size(x::ToStruct) = (length(_structsource(x)),) + +@inline _structfield(::Type{T}, x, j) where {T} = + x === missing ? ArrowTypes.default(T) : getfield(x, j) Base.@propagate_inbounds function Base.getindex(A::ToStruct{T,j}, i::Integer) where {T,j} @boundscheck checkbounds(A, i) - @inbounds x = A.data[i] - return x === missing ? ArrowTypes.default(T) : getfield(x, j) + x = _structsourcevalue(A, i) + return _structfield(T, x, j) +end + +function Base.iterate(A::ToStruct{T,j}) where {T,j} + state = iterate(_structsource(A)) + state === nothing && return nothing + x, st = state + return _structfield(T, x, j), st +end + +function Base.iterate(A::ToStruct{T,j}, st) where {T,j} + state = iterate(_structsource(A), st) + state === nothing && return nothing + x, st = state + return _structfield(T, x, j), st +end + +function writearray(io::IO, ::Type{T}, col::ToStruct{T,j}) where {T,j} + isbitstype(T) || return _writearrayfallback(io, T, col) + data = Vector{T}(undef, length(col)) + i = 1 + for x in col + @inbounds data[i] = x + i += 1 + end + return _writearraycontiguous(io, T, data) +end + +function writearray( + io::IO, + ::Type{UInt8}, + col::ToList{UInt8,stringtype,A}, +) where {stringtype,T,j,A<:ToStruct{T,j}} + off = _tolistoffset(col) + off == 0 || return _writearray_tolist_uint8(io, col) + len = length(col) + len <= 1_048_576 || return _writearray_tolist_uint8(io, col) + outer = _tolistdata(col) + data = _structsource(outer) + buf = Vector{UInt8}(undef, len) + pos = 1 + @inbounds for idx in eachindex(data) + chunk = _structfield(T, data[idx], j) + chunk === missing && continue + bytes = stringtype ? _codeunits(chunk) : chunk + for b in bytes + buf[pos] = b + pos += 1 + end + end + written = pos - 1 + GC.@preserve buf begin + return Base.unsafe_write(io, pointer(buf), written) + end end arrowvector(::StructKind, x::Struct, i, nl, fi, de, ded, meta; kw...) = x diff --git a/src/arraytypes/views.jl b/src/arraytypes/views.jl index 0a43f6f..0d23a70 100644 --- a/src/arraytypes/views.jl +++ b/src/arraytypes/views.jl @@ -21,6 +21,17 @@ struct ViewElement offset::Int32 end +const VIEW_ELEMENT_BYTES = sizeof(ViewElement) +const VIEW_LENGTH_BYTES = sizeof(Int32) +const VIEW_INLINE_BYTES = VIEW_ELEMENT_BYTES - VIEW_LENGTH_BYTES + +@inline _viewisinline(length::Integer) = length <= VIEW_INLINE_BYTES +@inline _viewinlinestart(i::Integer) = + ((i - 1) * VIEW_ELEMENT_BYTES) + VIEW_LENGTH_BYTES + 1 +@inline _viewinlineend(i::Integer, length::Integer) = _viewinlinestart(i) + length - 1 +@inline _viewinlineslice(inline::Vector{UInt8}, i::Integer, length::Integer) = + @view inline[_viewinlinestart(i):_viewinlineend(i, length)] + """ Arrow.View @@ -45,12 +56,8 @@ Base.size(l::View) = (l.ℓ,) if S <: Base.CodeUnits # BinaryView return !l.validity[i] ? missing : - v.length < 13 ? - Base.CodeUnits( - StringView( - @view l.inline[(((i - 1) * 16) + 5):(((i - 1) * 16) + 5 + v.length - 1)] - ), - ) : + _viewisinline(v.length) ? + Base.CodeUnits(StringView(_viewinlineslice(l.inline, i, v.length))) : Base.CodeUnits( StringView( @view l.buffers[v.bufindex + 1][(v.offset + 1):(v.offset + v.length)] @@ -59,12 +66,10 @@ Base.size(l::View) = (l.ℓ,) else # Utf8View return !l.validity[i] ? missing : - v.length < 13 ? + _viewisinline(v.length) ? ArrowTypes.fromarrow( T, - StringView( - @view l.inline[(((i - 1) * 16) + 5):(((i - 1) * 16) + 5 + v.length - 1)] - ), + StringView(_viewinlineslice(l.inline, i, v.length)), ) : ArrowTypes.fromarrow( T, diff --git a/src/eltypes.jl b/src/eltypes.jl index 52dbb80..98a2ab9 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -24,6 +24,190 @@ finaljuliatype(T) = T finaljuliatype(::Type{Missing}) = Missing finaljuliatype(::Type{Union{T,Missing}}) where {T} = Union{Missing,finaljuliatype(T)} +const RUN_END_ENCODED_UNSUPPORTED = "Run-End Encoded arrays are not supported yet" +const BOOL8_SYMBOL = Symbol("arrow.bool8") +const JSON_SYMBOL = Symbol("arrow.json") +const OPAQUE_SYMBOL = Symbol("arrow.opaque") +const PARQUET_VARIANT_SYMBOL = Symbol("arrow.parquet.variant") +const FIXED_SHAPE_TENSOR_SYMBOL = Symbol("arrow.fixed_shape_tensor") +const VARIABLE_SHAPE_TENSOR_SYMBOL = Symbol("arrow.variable_shape_tensor") + +@inline _canonicalextensionerror(sym::Symbol, msg::AbstractString) = + throw(ArgumentError("invalid canonical $(String(sym)) extension: $msg")) + +@inline _fieldchildren(field::Meta.Field) = + field.children === nothing ? Meta.Field[] : field.children + +@inline _jsonhaskey(x, key::AbstractString) = haskey(x, key) +@inline _jsonget(x, key::AbstractString) = x[key] + +function _parsecanonicalmetadata(sym::Symbol, metadata::String; required::Bool=false) + isempty(metadata) && + return required ? _canonicalextensionerror(sym, "metadata is required") : nothing + value = try + JSON3.read(metadata) + catch + _canonicalextensionerror(sym, "metadata must be valid JSON") + end + value isa JSON3.Object || + _canonicalextensionerror(sym, "metadata must be a JSON object") + return value +end + +function _parseintvector(sym::Symbol, value, label::AbstractString; allow_null::Bool=false) + value isa AbstractVector || + _canonicalextensionerror(sym, "\"$label\" must be a JSON array") + parsed = Vector{allow_null ? Union{Nothing,Int} : Int}() + for item in value + if allow_null && isnothing(item) + push!(parsed, nothing) + elseif item isa Integer + item >= 0 || + _canonicalextensionerror(sym, "\"$label\" values must be non-negative") + push!(parsed, Int(item)) + else + suffix = allow_null ? "integers or null" : "integers" + _canonicalextensionerror(sym, "\"$label\" must contain only $suffix") + end + end + return parsed +end + +function _parsestringvector(sym::Symbol, value, label::AbstractString) + value isa AbstractVector || + _canonicalextensionerror(sym, "\"$label\" must be a JSON array") + parsed = String[] + for item in value + item isa AbstractString || + _canonicalextensionerror(sym, "\"$label\" must contain only strings") + push!(parsed, String(item)) + end + return parsed +end + +function _validatepermutation(sym::Symbol, permutation::Vector{Int}, ndim::Int) + length(permutation) == ndim || + _canonicalextensionerror(sym, "\"permutation\" must have length $ndim") + length(unique(permutation)) == ndim || + _canonicalextensionerror(sym, "\"permutation\" must not contain duplicates") + return permutation +end + +function _extractdimensionalmetadata( + sym::Symbol, + metadata; + ndim::Union{Nothing,Int}=nothing, +) + metadata === nothing && return (nothing, nothing, nothing) + dim_names = + _jsonhaskey(metadata, "dim_names") ? + _parsestringvector(sym, _jsonget(metadata, "dim_names"), "dim_names") : nothing + permutation = + _jsonhaskey(metadata, "permutation") ? + _parseintvector(sym, _jsonget(metadata, "permutation"), "permutation") : nothing + uniform_shape = + _jsonhaskey(metadata, "uniform_shape") ? + _parseintvector( + sym, + _jsonget(metadata, "uniform_shape"), + "uniform_shape"; + allow_null=true, + ) : nothing + if ndim !== nothing + dim_names !== nothing && length(dim_names) == ndim || + isnothing(dim_names) || + _canonicalextensionerror(sym, "\"dim_names\" must have length $ndim") + permutation !== nothing && _validatepermutation(sym, permutation, ndim) + uniform_shape !== nothing && length(uniform_shape) == ndim || + isnothing(uniform_shape) || + _canonicalextensionerror(sym, "\"uniform_shape\" must have length $ndim") + end + return dim_names, permutation, uniform_shape +end + +@inline _isliststoragetype(x) = + x isa Union{Meta.List,Meta.LargeList,Meta.ListView,Meta.LargeListView} + +@inline _isbinarystoragetype(x) = + x isa Union{Meta.Binary,Meta.LargeBinary,Meta.BinaryView,Meta.FixedSizeBinary} + +function _validateparquetvariant(field::Meta.Field, metadata::String) + isempty(metadata) || _canonicalextensionerror( + PARQUET_VARIANT_SYMBOL, + "metadata must be the empty string", + ) + field + return +end + +function _validatefixedshapetensor(field::Meta.Field, metadata::String) + meta = _parsecanonicalmetadata(FIXED_SHAPE_TENSOR_SYMBOL, metadata; required=true) + _jsonhaskey(meta, "shape") || + _canonicalextensionerror(FIXED_SHAPE_TENSOR_SYMBOL, "\"shape\" is required") + shape = _parseintvector(FIXED_SHAPE_TENSOR_SYMBOL, _jsonget(meta, "shape"), "shape") + dim_names, permutation, _ = + _extractdimensionalmetadata(FIXED_SHAPE_TENSOR_SYMBOL, meta; ndim=length(shape)) + field.type isa Meta.FixedSizeList || _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "storage must be a FixedSizeList", + ) + length(collect(_fieldchildren(field))) == 1 || _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "storage must contain exactly one child field", + ) + expected = isempty(shape) ? 1 : prod(shape) + Int(field.type.listSize) == expected || _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "\"shape\" product $expected does not match FixedSizeList size $(field.type.listSize)", + ) + dim_names + permutation + return +end + +function _validatevariableshapetensor(field::Meta.Field, metadata::String) + field.type isa Meta.Struct || + _canonicalextensionerror(VARIABLE_SHAPE_TENSOR_SYMBOL, "storage must be a Struct") + children = Dict(String(child.name) => child for child in collect(_fieldchildren(field))) + keys(children) == Set(("data", "shape")) || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "storage must contain exactly \"data\" and \"shape\" fields", + ) + data_field = children["data"] + shape_field = children["shape"] + _isliststoragetype(data_field.type) || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"data\" field must use list storage", + ) + length(collect(_fieldchildren(data_field))) == 1 || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"data\" field must contain exactly one child field", + ) + shape_field.type isa Meta.FixedSizeList || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" field must use FixedSizeList storage", + ) + shape_children = collect(_fieldchildren(shape_field)) + length(shape_children) == 1 || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" field must contain exactly one child field", + ) + shape_value = only(shape_children) + shape_value.type isa Meta.Int || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" values must use Int32 storage", + ) + (shape_value.type.bitWidth == 32 && shape_value.type.is_signed) || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" values must use signed Int32 storage", + ) + ndim = Int(shape_field.type.listSize) + meta = _parsecanonicalmetadata(VARIABLE_SHAPE_TENSOR_SYMBOL, metadata) + _extractdimensionalmetadata(VARIABLE_SHAPE_TENSOR_SYMBOL, meta; ndim=ndim) + return +end + """ Given a FlatBuffers.Builder and a Julia column or column eltype, Write the field.type flatbuffer definition of the eltype @@ -41,19 +225,25 @@ end function juliaeltype(f::Meta.Field, meta::AbstractDict{String,String}, convert::Bool) TT = juliaeltype(f, convert) - !convert && return TT - T = finaljuliatype(TT) - if haskey(meta, "ARROW:extension:name") - typename = meta["ARROW:extension:name"] - metadata = get(meta, "ARROW:extension:metadata", "") - JT = ArrowTypes.JuliaType(Val(Symbol(typename)), maybemissing(TT), metadata) + spec = _extensionspec(meta) + if spec !== nothing + _validatebuiltinextension(spec, f) + !convert && return TT + T = finaljuliatype(TT) + storageT = + spec.name === TIMESTAMP_WITH_OFFSET_SYMBOL ? + maybemissing(juliaeltype(f, false)) : maybemissing(TT) + JT = _resolveextensionjuliatype(spec, storageT) if JT !== nothing return f.nullable ? Union{JT,Missing} : JT else - @warn "unsupported ARROW:extension:name type: \"$typename\", arrow type = $TT" maxlog = + typename = _extensiontypename(spec) + @warn "unsupported $(EXTENSION_NAME_KEY) type: \"$typename\", arrow type = $TT" maxlog = 1 _id = hash((:juliaeltype, typename, TT)) end end + !convert && return TT + T = finaljuliatype(TT) return something(TT, T) end @@ -108,6 +298,105 @@ function arrowtype(b, ::Type{T}) where {T<:Integer} return Meta.Int, Meta.intEnd(b), nothing end +struct Bool8 + value::Bool +end + +Bool8(x::Integer) = Bool8(!iszero(x)) + +Base.Bool(x::Bool8) = getfield(x, :value) +Base.convert(::Type{Bool}, x::Bool8) = Bool(x) +Base.convert(::Type{Int8}, x::Bool8) = Int8(Bool(x)) +Base.zero(::Type{Bool8}) = Bool8(false) +Base.:(==)(x::Bool8, y::Bool8) = Bool(x) == Bool(y) +Base.isequal(x::Bool8, y::Bool8) = isequal(Bool(x), Bool(y)) + +ArrowTypes.ArrowType(::Type{Bool8}) = _builtinarrowtype(Bool8) +ArrowTypes.toarrow(x::Bool8) = _builtintoarrow(x) +ArrowTypes.arrowname(::Type{Bool8}) = _builtinarrowname(Bool8) +ArrowTypes.JuliaType(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = + _builtinextensionjuliatype(Val(BOOL8_SYMBOL), Int8, metadata) +ArrowTypes.fromarrow(::Type{Bool8}, x::Int8) = _builtinfromarrow(Bool8, x) +ArrowTypes.default(::Type{Bool8}) = _builtindefault(Bool8) + +function writearray( + io::IO, + ::Type{Int8}, + col::ArrowTypes.ToArrow{Int8,A}, +) where {A<:AbstractVector{Bool8}} + data = ArrowTypes._sourcedata(col) + strides(data) == (1,) || return _writearrayfallback(io, Int8, col) + return Base.write(io, reinterpret(Int8, data)) +end + +struct JSONText{S<:AbstractString} + value::S +end + +Base.String(x::JSONText) = String(getfield(x, :value)) +Base.convert(::Type{String}, x::JSONText) = String(x) +Base.:(==)(x::JSONText, y::JSONText) = getfield(x, :value) == getfield(y, :value) +Base.isequal(x::JSONText, y::JSONText) = isequal(getfield(x, :value), getfield(y, :value)) + +ArrowTypes.ArrowType(::Type{JSONText{S}}) where {S<:AbstractString} = + _builtinarrowtype(JSONText{S}) +ArrowTypes.toarrow(x::JSONText) = _builtintoarrow(x) +ArrowTypes.arrowname(::Type{JSONText{S}}) where {S<:AbstractString} = + _builtinarrowname(JSONText{S}) +ArrowTypes.JuliaType( + ::Val{JSON_SYMBOL}, + ::Type{S}, + metadata::String, +) where {S<:AbstractString} = _builtinextensionjuliatype(Val(JSON_SYMBOL), S, metadata) +ArrowTypes.fromarrow(::Type{JSONText{String}}, ptr::Ptr{UInt8}, len::Int) = + _builtinfromarrow(JSONText{String}, ptr, len) +ArrowTypes.fromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = + _builtinfromarrow(JSONText{S}, x) +ArrowTypes.default(::Type{JSONText{S}}) where {S<:AbstractString} = + _builtindefault(JSONText{S}) + +ArrowTypes.JuliaType(::Val{OPAQUE_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(OPAQUE_SYMBOL), S, metadata) +ArrowTypes.JuliaType(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(PARQUET_VARIANT_SYMBOL), S, metadata) +ArrowTypes.JuliaType(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(FIXED_SHAPE_TENSOR_SYMBOL), S, metadata) +ArrowTypes.JuliaType(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(VARIABLE_SHAPE_TENSOR_SYMBOL), S, metadata) + +@inline function _jsonstringliteral(x::AbstractString) + return '"' * escape_string(x) * '"' +end + +opaquemetadata(type_name::AbstractString, vendor_name::AbstractString) = + _builtinopaquemetadata(type_name, vendor_name) + +variantmetadata() = _builtinvariantmetadata() + +function fixedshapetensormetadata( + shape::AbstractVector{<:Integer}; + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + return _builtinfixedshapetensormetadata( + shape; + dim_names=dim_names, + permutation=permutation, + ) +end + +function variableshapetensormetadata(; + uniform_shape::Union{Nothing,AbstractVector}=nothing, + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + return _builtinvariableshapetensormetadata(; + uniform_shape=uniform_shape, + dim_names=dim_names, + permutation=permutation, + ) +end + # primitive types function juliaeltype(f::Meta.Field, fp::Meta.FloatingPoint, convert) if fp.precision == Meta.Precision.HALF @@ -265,6 +554,17 @@ end Base.zero(::Type{Timestamp{U,T}}) where {U,T} = Timestamp{U,T}(Int64(0)) +struct TimestampWithOffset{U} + timestamp::Timestamp{U,:UTC} + offset_minutes::Int16 +end + +TimestampWithOffset(timestamp::Timestamp{U,:UTC}, offset_minutes::Integer) where {U} = + TimestampWithOffset{U}(timestamp, Int16(offset_minutes)) + +Base.zero(::Type{TimestampWithOffset{U}}) where {U} = + TimestampWithOffset{U}(zero(Timestamp{U,:UTC}), Int16(0)) + function juliaeltype(f::Meta.Field, x::Meta.Timestamp, convert) return Timestamp{x.unit,x.timezone === nothing ? nothing : Symbol(x.timezone)} end @@ -325,31 +625,64 @@ ArrowTypes.fromarrow(::Type{Dates.DateTime}, x::Date{Meta.DateUnit.MILLISECOND,I convert(Dates.DateTime, x) ArrowTypes.default(::Type{Dates.DateTime}) = Dates.DateTime(1, 1, 1, 1, 1, 1) -ArrowTypes.ArrowType(::Type{ZonedDateTime}) = Timestamp -ArrowTypes.toarrow(x::ZonedDateTime) = - convert(Timestamp{Meta.TimeUnit.MILLISECOND,Symbol(x.timezone)}, x) +ArrowTypes.ArrowType(::Type{ZonedDateTime}) = _builtinarrowtype(ZonedDateTime) +ArrowTypes.toarrow(x::ZonedDateTime) = _builtintoarrow(x) const ZONEDDATETIME_SYMBOL = Symbol("JuliaLang.ZonedDateTime-UTC") -ArrowTypes.arrowname(::Type{ZonedDateTime}) = ZONEDDATETIME_SYMBOL -ArrowTypes.JuliaType(::Val{ZONEDDATETIME_SYMBOL}, S) = ZonedDateTime -ArrowTypes.fromarrow(::Type{ZonedDateTime}, x::Timestamp) = convert(ZonedDateTime, x) -ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) = - TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC") +ArrowTypes.arrowname(::Type{ZonedDateTime}) = _builtinarrowname(ZonedDateTime) +ArrowTypes.JuliaType(::Val{ZONEDDATETIME_SYMBOL}, S) = + _builtinextensionjuliatype(Val(ZONEDDATETIME_SYMBOL), S) +ArrowTypes.fromarrow(::Type{ZonedDateTime}, x::Timestamp) = + _builtinfromarrow(ZonedDateTime, x) +ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) = _builtindefault(ZonedDateTime) + +const TIMESTAMP_WITH_OFFSET_SYMBOL = Symbol("arrow.timestamp_with_offset") +ArrowTypes.ArrowType(::Type{TimestampWithOffset{U}}) where {U} = + _builtinarrowtype(TimestampWithOffset{U}) +ArrowTypes.toarrow(x::TimestampWithOffset{U}) where {U} = _builtintoarrow(x) +ArrowTypes.arrowname(::Type{TimestampWithOffset{U}}) where {U} = + _builtinarrowname(TimestampWithOffset{U}) +ArrowTypes.JuliaType( + ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, + ::Type{NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, + metadata::String, +) where {U} = _builtinextensionjuliatype( + Val(TIMESTAMP_WITH_OFFSET_SYMBOL), + NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}, + metadata, +) +ArrowTypes.default(::Type{TimestampWithOffset{U}}) where {U} = + _builtindefault(TimestampWithOffset{U}) +ArrowTypes.fromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:timestamp, :offset_minutes)}, + timestamp::Timestamp{U,:UTC}, + offset_minutes::Int16, +) where {U} = _builtinfromarrowstruct( + TimestampWithOffset{U}, + Val((:timestamp, :offset_minutes)), + timestamp, + offset_minutes, +) +ArrowTypes.fromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:offset_minutes, :timestamp)}, + offset_minutes::Int16, + timestamp::Timestamp{U,:UTC}, +) where {U} = _builtinfromarrowstruct( + TimestampWithOffset{U}, + Val((:offset_minutes, :timestamp)), + offset_minutes, + timestamp, +) # Backwards compatibility: older versions of Arrow saved ZonedDateTime's with this metdata: const OLD_ZONEDDATETIME_SYMBOL = Symbol("JuliaLang.ZonedDateTime") # and stored the local time instead of the UTC time. struct LocalZonedDateTime end -ArrowTypes.JuliaType(::Val{OLD_ZONEDDATETIME_SYMBOL}, S) = LocalZonedDateTime -function ArrowTypes.fromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} - (U === Meta.TimeUnit.MICROSECOND || U == Meta.TimeUnit.NANOSECOND) && - warntimestamp(U, ZonedDateTime) - return ZonedDateTime( - Dates.DateTime( - Dates.UTM(Int64(Dates.toms(periodtype(U)(x.x)) + UNIX_EPOCH_DATETIME)), - ), - TimeZone(String(TZ)), - ) -end +ArrowTypes.JuliaType(::Val{OLD_ZONEDDATETIME_SYMBOL}, S) = + _builtinextensionjuliatype(Val(OLD_ZONEDDATETIME_SYMBOL), S) +ArrowTypes.fromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} = + _builtinfromarrow(LocalZonedDateTime, x) """ Arrow.ToTimestamp(x::AbstractVector{ZonedDateTime}) @@ -390,6 +723,10 @@ function juliaeltype(f::Meta.Field, x::Meta.Interval, convert) return Interval{x.unit,bitwidth(x.unit)} end +function juliaeltype(f::Meta.Field, x::Meta.RunEndEncoded, convert) + return juliaeltype(f.children[2], buildmetadata(f.children[2]), convert) +end + function arrowtype(b, ::Type{Interval{U,T}}) where {U,T} Meta.intervalStart(b) Meta.intervalAddUnit(b, U) diff --git a/src/flight/Flight.jl b/src/flight/Flight.jl new file mode 100644 index 0000000..f6ae04a --- /dev/null +++ b/src/flight/Flight.jl @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module Flight + +using Base64 +using ProtoBuf +using gRPCClient +using Tables + +const ArrowParent = parentmodule(@__MODULE__) + +include("exports.jl") +include("protocol.jl") +include("client.jl") +include("server.jl") +include("convert.jl") + +end # module Flight diff --git a/src/flight/client.jl b/src/flight/client.jl new file mode 100644 index 0000000..dc196bd --- /dev/null +++ b/src/flight/client.jl @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("client/constants.jl") +include("client/locations.jl") +include("client/types.jl") +include("client/headers.jl") +include("client/transport.jl") +include("client/protocol_clients.jl") +include("client/auth.jl") +include("client/rpc_methods.jl") diff --git a/src/flight/client/auth.jl b/src/flight/client/auth.jl new file mode 100644 index 0000000..9750cbb --- /dev/null +++ b/src/flight/client/auth.jl @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +handshake( + client::Client, + request::Channel{Protocol.HandshakeRequest}, + response::Channel{Protocol.HandshakeResponse}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _handshake_client(client; kwargs...), + request, + response, + headers=_merge_headers(client, headers), +) + +function handshake( + client::Client; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + request = Channel{Protocol.HandshakeRequest}(request_capacity) + response = Channel{Protocol.HandshakeResponse}(response_capacity) + req = handshake(client, request, response; headers=headers, kwargs...) + return req, request, response +end + +function authenticate( + client::Client, + requests::AbstractVector{<:Protocol.HandshakeRequest}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + req, request_channel, response_channel = handshake(client; headers=headers, kwargs...) + for request in requests + put!(request_channel, request) + end + close(request_channel) + + responses = collect(response_channel) + gRPCClient.grpc_async_await(req) + + isempty(responses) && + throw(ArgumentError("Arrow Flight handshake returned no response messages")) + + return withtoken(client, responses[end].payload), responses +end + +function authenticate( + client::Client, + payloads::AbstractVector{<:AbstractVector{UInt8}}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + requests = [ + Protocol.HandshakeRequest(UInt64(0), Vector{UInt8}(payload)) for payload in payloads + ] + return authenticate(client, requests; headers=headers, kwargs...) +end + +function authenticate( + client::Client, + username::AbstractString, + password::AbstractString; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return authenticate( + client, + [Vector{UInt8}(codeunits(username)), Vector{UInt8}(codeunits(password))]; + headers=headers, + kwargs..., + ) +end diff --git a/src/flight/client/constants.jl b/src/flight/client/constants.jl new file mode 100644 index 0000000..f527698 --- /dev/null +++ b/src/flight/client/constants.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const DEFAULT_MAX_MESSAGE_LENGTH = 4 * 1024 * 1024 +const DEFAULT_STREAM_BUFFER = 16 +const HeaderValue = Union{String,Vector{UInt8}} +const HeaderPair = Pair{String,HeaderValue} +const AUTH_TOKEN_HEADER = "auth-token-bin" diff --git a/src/flight/client/headers.jl b/src/flight/client/headers.jl new file mode 100644 index 0000000..6f1ca7f --- /dev/null +++ b/src/flight/client/headers.jl @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +_normalize_header_value(value::AbstractString) = String(value) +_normalize_header_value(value::AbstractVector{UInt8}) = Vector{UInt8}(value) +function _normalize_header_value(value) + throw( + ArgumentError( + "Arrow Flight header values must be strings or byte vectors, got $(typeof(value))", + ), + ) +end + +function _normalize_headers(headers::AbstractVector{<:Pair}) + normalized = HeaderPair[] + for header in headers + push!(normalized, String(first(header)) => _normalize_header_value(last(header))) + end + return normalized +end + +withheaders(client::Client, headers::Pair...) = withheaders(client, collect(headers)) + +function withheaders(client::Client, headers::AbstractVector{<:Pair}) + merged_headers = copy(client.headers) + append!(merged_headers, _normalize_headers(headers)) + return _rebuild_client(client; headers=merged_headers) +end + +withtoken(client::Client, token::AbstractString) = + withtoken(client, Vector{UInt8}(codeunits(token))) +withtoken(client::Client, token::AbstractVector{UInt8}) = + _withreplacedheader(client, AUTH_TOKEN_HEADER => Vector{UInt8}(token)) + +function _withreplacedheader(client::Client, header::Pair) + normalized_header = String(first(header)) => _normalize_header_value(last(header)) + name = lowercase(first(normalized_header)) + filtered_headers = HeaderPair[ + existing for existing in client.headers if lowercase(first(existing)) != name + ] + push!(filtered_headers, normalized_header) + return _rebuild_client(client; headers=filtered_headers) +end + +function _header_lines(headers::AbstractVector{HeaderPair}) + lines = String[] + for (name, value) in headers + isempty(name) && throw(ArgumentError("Arrow Flight header names must not be empty")) + any(ch -> ch == '\r' || ch == '\n', name) && + throw(ArgumentError("Arrow Flight header names must not contain newlines")) + rendered_value = _render_header_value(name, value) + any(ch -> ch == '\r' || ch == '\n', rendered_value) && + throw(ArgumentError("Arrow Flight header values must not contain newlines")) + push!(lines, string(name, ": ", rendered_value)) + end + return lines +end + +function _render_header_value(name::String, value::String) + if endswith(lowercase(name), "-bin") + return Base64.base64encode(codeunits(value)) + end + return value +end + +function _render_header_value(name::String, value::Vector{UInt8}) + endswith(lowercase(name), "-bin") || + throw(ArgumentError("Arrow Flight binary header values require a '-bin' suffix")) + return Base64.base64encode(value) +end + +function _merge_headers(client::Client, headers::AbstractVector{<:Pair}=HeaderPair[]) + merged_headers = copy(client.headers) + append!(merged_headers, _normalize_headers(headers)) + return merged_headers +end diff --git a/src/flight/client/locations.jl b/src/flight/client/locations.jl new file mode 100644 index 0000000..846d72c --- /dev/null +++ b/src/flight/client/locations.jl @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _parse_location(uri::String) + match_result = match( + r"^(grpc\+tls|grpc\+tcp|grpc|https|http)://(\[[^\]]+\]|[^:/?#]+):([0-9]+)(/.*)?$", + uri, + ) + isnothing(match_result) && + throw(ArgumentError("unsupported Arrow Flight location URI: $uri")) + scheme = match_result.captures[1] + host = match_result.captures[2] + port = parse(Int64, match_result.captures[3]) + secure = scheme == "grpc+tls" || scheme == "https" + if startswith(host, "[") && endswith(host, "]") + host = host[2:(end - 1)] + end + return secure, host, port +end diff --git a/src/flight/client/methods/actions.jl b/src/flight/client/methods/actions.jl new file mode 100644 index 0000000..7de2201 --- /dev/null +++ b/src/flight/client/methods/actions.jl @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +doaction( + client::Client, + action::Protocol.Action, + response::Channel{Protocol.Result}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doaction_client(client; kwargs...), + action, + response; + headers=_merge_headers(client, headers), +) + +function doaction( + client::Client, + action::Protocol.Action; + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.Result}(response_capacity) + req = doaction(client, action, response; headers=headers, kwargs...) + return req, response +end + +function listactions( + client::Client, + response::Channel{Protocol.ActionType}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_async_request( + client, + _listactions_client(client; kwargs...), + Protocol.Empty(), + response, + headers=_merge_headers(client, headers), + ) +end + +function listactions( + client::Client; + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.ActionType}(response_capacity) + req = listactions(client, response; headers=headers, kwargs...) + return req, response +end diff --git a/src/flight/client/methods/data.jl b/src/flight/client/methods/data.jl new file mode 100644 index 0000000..3015977 --- /dev/null +++ b/src/flight/client/methods/data.jl @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +doget( + client::Client, + ticket::Protocol.Ticket, + response::Channel{Protocol.FlightData}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doget_client(client; kwargs...), + ticket, + response; + headers=_merge_headers(client, headers), +) + +function doget( + client::Client, + ticket::Protocol.Ticket; + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.FlightData}(response_capacity) + req = doget(client, ticket, response; headers=headers, kwargs...) + return req, response +end + +doput( + client::Client, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.PutResult}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doput_client(client; kwargs...), + request, + response; + headers=_merge_headers(client, headers), +) + +function doput( + client::Client; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + response = Channel{Protocol.PutResult}(response_capacity) + req = doput(client, request, response; headers=headers, kwargs...) + return req, request, response +end + +function doput( + client::Client, + source, + response::Channel{Protocol.PutResult}; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + grpc_request = doput(client, request, response; headers=headers, kwargs...) + producer = errormonitor( + Threads.@spawn putflightdata!( + request, + source; + close=true, + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + app_metadata=app_metadata, + ) + ) + return FlightAsyncRequest(grpc_request, producer) +end + +function doput( + client::Client, + source; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, + kwargs..., +) + response = Channel{Protocol.PutResult}(response_capacity) + req = doput( + client, + source, + response; + request_capacity=request_capacity, + headers=headers, + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + app_metadata=app_metadata, + kwargs..., + ) + return req, response +end + +doexchange( + client::Client, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.FlightData}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doexchange_client(client; kwargs...), + request, + response, + headers=_merge_headers(client, headers), +) + +function doexchange( + client::Client; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + response = Channel{Protocol.FlightData}(response_capacity) + req = doexchange(client, request, response; headers=headers, kwargs...) + return req, request, response +end + +function doexchange( + client::Client, + source, + response::Channel{Protocol.FlightData}; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + grpc_request = doexchange(client, request, response; headers=headers, kwargs...) + producer = errormonitor( + Threads.@spawn putflightdata!( + request, + source; + close=true, + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + app_metadata=app_metadata, + ) + ) + return FlightAsyncRequest(grpc_request, producer) +end + +function doexchange( + client::Client, + source; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, + kwargs..., +) + response = Channel{Protocol.FlightData}(response_capacity) + req = doexchange( + client, + source, + response; + request_capacity=request_capacity, + headers=headers, + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + app_metadata=app_metadata, + kwargs..., + ) + return req, response +end diff --git a/src/flight/client/methods/discovery.jl b/src/flight/client/methods/discovery.jl new file mode 100644 index 0000000..5635154 --- /dev/null +++ b/src/flight/client/methods/discovery.jl @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +listflights( + client::Client, + criteria::Protocol.Criteria, + response::Channel{Protocol.FlightInfo}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _listflights_client(client; kwargs...), + criteria, + response, + headers=_merge_headers(client, headers), +) + +function listflights( + client::Client, + criteria::Protocol.Criteria=Protocol.Criteria(UInt8[]); + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.FlightInfo}(response_capacity) + req = listflights(client, criteria, response; headers=headers, kwargs...) + return req, response +end + +function getflightinfo( + client::Client, + descriptor::Protocol.FlightDescriptor; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_sync_request( + client, + _getflightinfo_client(client; kwargs...), + descriptor; + headers=_merge_headers(client, headers), + ) +end + +function pollflightinfo( + client::Client, + descriptor::Protocol.FlightDescriptor; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_sync_request( + client, + _pollflightinfo_client(client; kwargs...), + descriptor; + headers=_merge_headers(client, headers), + ) +end + +function getschema( + client::Client, + descriptor::Protocol.FlightDescriptor; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_sync_request( + client, + _getschema_client(client; kwargs...), + descriptor; + headers=_merge_headers(client, headers), + ) +end diff --git a/src/flight/client/protocol_clients.jl b/src/flight/client/protocol_clients.jl new file mode 100644 index 0000000..5524cbb --- /dev/null +++ b/src/flight/client/protocol_clients.jl @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +_handshake_client(client::Client; kwargs...) = Protocol.FlightService_Handshake_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_listflights_client(client::Client; kwargs...) = Protocol.FlightService_ListFlights_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_getflightinfo_client(client::Client; kwargs...) = + Protocol.FlightService_GetFlightInfo_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., + ) + +_pollflightinfo_client(client::Client; kwargs...) = + Protocol.FlightService_PollFlightInfo_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., + ) + +_getschema_client(client::Client; kwargs...) = Protocol.FlightService_GetSchema_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doget_client(client::Client; kwargs...) = Protocol.FlightService_DoGet_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doput_client(client::Client; kwargs...) = Protocol.FlightService_DoPut_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doexchange_client(client::Client; kwargs...) = Protocol.FlightService_DoExchange_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doaction_client(client::Client; kwargs...) = Protocol.FlightService_DoAction_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_listactions_client(client::Client; kwargs...) = Protocol.FlightService_ListActions_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) diff --git a/src/flight/client/rpc_methods.jl b/src/flight/client/rpc_methods.jl new file mode 100644 index 0000000..7e8e2cd --- /dev/null +++ b/src/flight/client/rpc_methods.jl @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("methods/discovery.jl") +include("methods/data.jl") +include("methods/actions.jl") diff --git a/src/flight/client/transport.jl b/src/flight/client/transport.jl new file mode 100644 index 0000000..5b6ede7 --- /dev/null +++ b/src/flight/client/transport.jl @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _append_headers_unlocked!( + req::gRPCClient.gRPCRequest, + headers::AbstractVector{HeaderPair}, +) + isempty(headers) && return req + for header_line in _header_lines(headers) + req.headers = gRPCClient.curl_slist_append(req.headers, header_line) + end + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_HTTPHEADER, req.headers) + return req +end + +function _apply_tls_options_unlocked!(client::Client, req::gRPCClient.gRPCRequest) + if !client.secure + return req + end + + if client.disable_server_verification + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYPEER, Clong(0)) + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYHOST, Clong(0)) + else + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYPEER, Clong(1)) + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYHOST, Clong(2)) + end + + !isnothing(client.tls_root_certs) && gRPCClient.curl_easy_setopt( + req.easy, + gRPCClient.CURLOPT_CAINFO, + client.tls_root_certs, + ) + !isnothing(client.cert_chain) && + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSLCERT, client.cert_chain) + !isnothing(client.private_key) && + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSLKEY, client.private_key) + !isnothing(client.key_password) && gRPCClient.curl_easy_setopt( + req.easy, + gRPCClient.CURLOPT_KEYPASSWD, + client.key_password, + ) + + return req +end + +function _apply_client_options_unlocked!( + client::Client, + req::gRPCClient.gRPCRequest, + headers::AbstractVector{HeaderPair}, +) + _append_headers_unlocked!(req, headers) + return _apply_tls_options_unlocked!(client, req) +end + +function _grpc_sync_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,false,TResponse,false}, + request::TRequest; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + req = lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request) + _apply_client_options_unlocked!(client, req, headers) + end + return gRPCClient.grpc_async_await(rpc_client, req) +end + +function _grpc_async_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,false,TResponse,true}, + request::TRequest, + response::Channel{TResponse}; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + return lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request, response) + _apply_client_options_unlocked!(client, req, headers) + end +end + +function _grpc_async_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,true,TResponse,false}, + request::Channel{TRequest}, + response::Channel{TResponse}; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + return lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request, response) + _apply_client_options_unlocked!(client, req, headers) + end +end + +function _grpc_async_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,true,TResponse,true}, + request::Channel{TRequest}, + response::Channel{TResponse}; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + return lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request, response) + _apply_client_options_unlocked!(client, req, headers) + end +end + +struct FlightAsyncRequest{R} + request::R + producer::Union{Nothing,Task} +end + +function Base.wait(req::FlightAsyncRequest) + producer = getfield(req, :producer) + isnothing(producer) || wait(producer) + return wait(getfield(req, :request)) +end + +function gRPCClient.grpc_async_await(req::FlightAsyncRequest) + producer = getfield(req, :producer) + isnothing(producer) || wait(producer) + return gRPCClient.grpc_async_await(getfield(req, :request)) +end + +function gRPCClient.grpc_async_await( + client::gRPCClient.gRPCServiceClient{TRequest,true,TResponse,false}, + req::FlightAsyncRequest, +) where {TRequest<:Any,TResponse<:Any} + producer = getfield(req, :producer) + isnothing(producer) || wait(producer) + return gRPCClient.grpc_async_await(client, getfield(req, :request)) +end + +_default_rpc_options(client::Client) = ( + secure=client.secure, + grpc=client.grpc, + deadline=client.deadline, + keepalive=client.keepalive, + max_send_message_length=client.max_send_message_length, + max_recieve_message_length=client.max_recieve_message_length, +) + +_rpc_options(client::Client; kwargs...) = + merge(_default_rpc_options(client), NamedTuple(kwargs)) diff --git a/src/flight/client/types.jl b/src/flight/client/types.jl new file mode 100644 index 0000000..77d0bdd --- /dev/null +++ b/src/flight/client/types.jl @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +struct Client + host::String + port::Int64 + secure::Bool + grpc::gRPCClient.gRPCCURL + deadline::Float64 + keepalive::Float64 + max_send_message_length::Int64 + max_recieve_message_length::Int64 + headers::Vector{HeaderPair} + tls_root_certs::Union{Nothing,String} + cert_chain::Union{Nothing,String} + private_key::Union{Nothing,String} + key_password::Union{Nothing,String} + disable_server_verification::Bool +end + +function Client( + host, + port; + secure::Bool=false, + grpc::gRPCClient.gRPCCURL=gRPCClient.grpc_global_handle(), + deadline::Real=10, + keepalive::Real=60, + max_send_message_length::Integer=DEFAULT_MAX_MESSAGE_LENGTH, + max_recieve_message_length::Integer=DEFAULT_MAX_MESSAGE_LENGTH, + headers::AbstractVector{<:Pair}=HeaderPair[], + tls_root_certs::Union{Nothing,AbstractString}=nothing, + cert_chain::Union{Nothing,AbstractString}=nothing, + private_key::Union{Nothing,AbstractString}=nothing, + key_password::Union{Nothing,AbstractString}=nothing, + disable_server_verification::Bool=false, +) + Client( + String(host), + Int64(port), + secure, + grpc, + Float64(deadline), + Float64(keepalive), + Int64(max_send_message_length), + Int64(max_recieve_message_length), + _normalize_headers(headers), + isnothing(tls_root_certs) ? nothing : String(tls_root_certs), + isnothing(cert_chain) ? nothing : String(cert_chain), + isnothing(private_key) ? nothing : String(private_key), + isnothing(key_password) ? nothing : String(key_password), + disable_server_verification, + ) +end + +Client(location::Protocol.Location; kwargs...) = Client(location.uri; kwargs...) + +function Client(uri::AbstractString; kwargs...) + secure, host, port = _parse_location(String(uri)) + Client(host, port; secure=secure, kwargs...) +end + +function _rebuild_client(client::Client; headers::AbstractVector{<:Pair}=client.headers) + return Client( + client.host, + client.port; + secure=client.secure, + grpc=client.grpc, + deadline=client.deadline, + keepalive=client.keepalive, + max_send_message_length=client.max_send_message_length, + max_recieve_message_length=client.max_recieve_message_length, + headers=headers, + tls_root_certs=client.tls_root_certs, + cert_chain=client.cert_chain, + private_key=client.private_key, + key_password=client.key_password, + disable_server_verification=client.disable_server_verification, + ) +end diff --git a/src/flight/convert.jl b/src/flight/convert.jl new file mode 100644 index 0000000..b6d8289 --- /dev/null +++ b/src/flight/convert.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("convert/constants.jl") +include("convert/framing.jl") +include("convert/schema.jl") +include("convert/streaming.jl") +include("convert/flightdata.jl") diff --git a/src/flight/convert/constants.jl b/src/flight/convert/constants.jl new file mode 100644 index 0000000..80db62e --- /dev/null +++ b/src/flight/convert/constants.jl @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const DEFAULT_IPC_ALIGNMENT = 8 + +_collect_messages(messages::AbstractVector{<:Protocol.FlightData}) = messages +_collect_messages(messages) = collect(messages) diff --git a/src/flight/convert/flightdata.jl b/src/flight/convert/flightdata.jl new file mode 100644 index 0000000..2faf299 --- /dev/null +++ b/src/flight/convert/flightdata.jl @@ -0,0 +1,283 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _sourcedefaultcolmetadata(cols) + sch = Tables.schema(cols) + isnothing(sch) && return nothing + colmeta = Dict{Symbol,Any}() + Tables.eachcolumn(sch, cols) do col, _, nm + meta = ArrowParent.getmetadata(col) + isnothing(meta) || (colmeta[nm] = meta) + end + isempty(colmeta) && return nothing + return ArrowParent._normalizecolmeta(colmeta) +end + +struct FlightAppMetadataSource{T,M} + source::T + app_metadata::M +end + +ArrowParent.getmetadata(x::FlightAppMetadataSource) = ArrowParent.getmetadata(x.source) + +""" + Arrow.Flight.withappmetadata(source; app_metadata) + +Return a lightweight wrapper around `source` that carries batch-wise Flight +`app_metadata` alongside the Arrow payload. The wrapper can be passed directly +to [`Arrow.Flight.flightdata`](@ref), [`Arrow.Flight.putflightdata!`](@ref), +or source-based [`Arrow.Flight.doexchange`](@ref) without manually threading +`app_metadata=...` through each call site. +""" +withappmetadata(source; app_metadata) = + isnothing(app_metadata) ? source : FlightAppMetadataSource(source, app_metadata) + +function _unwrap_app_metadata_source(source, app_metadata) + source isa FlightAppMetadataSource || return source, app_metadata + isnothing(app_metadata) || throw( + ArgumentError( + "app_metadata cannot be provided both via Arrow.Flight.withappmetadata(...) and the app_metadata keyword", + ), + ) + return source.source, source.app_metadata +end + +_is_app_metadata_value(x) = x isa AbstractString || x isa AbstractVector{UInt8} + +function _normalize_app_metadata_value(value) + value === nothing && return UInt8[] + value isa AbstractString && return Vector{UInt8}(codeunits(value)) + value isa AbstractVector{UInt8} && return Vector{UInt8}(value) + throw( + ArgumentError( + "app_metadata entries must be AbstractString, AbstractVector{UInt8}, or nothing", + ), + ) +end + +function _normalize_app_metadata_source(app_metadata) + isnothing(app_metadata) && return nothing + return _is_app_metadata_value(app_metadata) ? (app_metadata,) : app_metadata +end + +_app_metadata_cursor(app_metadata) = + let metadata_iter = _normalize_app_metadata_source(app_metadata) + isnothing(metadata_iter) ? nothing : + (iter=metadata_iter, state=nothing, started=false) + end + +function _next_app_metadata(cursor) + isnothing(cursor) && return UInt8[], cursor + iter = cursor.iter + next = cursor.started ? iterate(iter, cursor.state) : iterate(iter) + isnothing(next) && throw( + ArgumentError("app_metadata was exhausted before all record batches were emitted"), + ) + value, state = next + return _normalize_app_metadata_value(value), (iter=iter, state=state, started=true) +end + +function _ensure_app_metadata_consumed(cursor) + isnothing(cursor) && return nothing + next = cursor.started ? iterate(cursor.iter, cursor.state) : iterate(cursor.iter) + isnothing(next) && return nothing + throw(ArgumentError("app_metadata contains more entries than source partitions")) +end + +function _partition_with_app_metadata(tbl, cursor) + app_metadata, cursor = _next_app_metadata(cursor) + return tbl, app_metadata, cursor +end + +function _emitflightdata!( + emit, + source; + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, +) + source, app_metadata = _unwrap_app_metadata_source(source, app_metadata) + dictencodings = Dict{Int64,Any}() + schema = Ref{Tables.Schema}() + normalized_colmetadata = ArrowParent._normalizecolmeta(colmetadata) + source_meta = isnothing(metadata) ? ArrowParent.getmetadata(source) : metadata + source_colmetadata = isnothing(colmetadata) ? nothing : normalized_colmetadata + app_metadata_cursor = _app_metadata_cursor(app_metadata) + + for partition in Tables.partitions(source) + tbl, record_app_metadata, app_metadata_cursor = + _partition_with_app_metadata(partition, app_metadata_cursor) + tblcols = Tables.columns(tbl) + if isnothing(metadata) + tblmeta = ArrowParent.getmetadata(tbl) + isnothing(tblmeta) && (tblmeta = source_meta) + else + tblmeta = metadata + end + if isnothing(colmetadata) + tblcolmetadata = _sourcedefaultcolmetadata(tblcols) + isnothing(tblcolmetadata) && (tblcolmetadata = source_colmetadata) + else + tblcolmetadata = normalized_colmetadata + end + cols = ArrowParent.toarrowtable( + tblcols, + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + tblmeta, + tblcolmetadata, + ) + if !isassigned(schema) + schema[] = Tables.schema(cols) + emit( + _flightdata_message( + ArrowParent.makeschemamsg(schema[], cols); + descriptor=descriptor, + alignment=alignment, + ), + ) + if !isempty(dictencodings) + for (id, delock) in sort!(collect(dictencodings); by=x -> x.first, rev=true) + de = delock.value + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + emit( + _flightdata_message( + ArrowParent.makedictionarybatchmsg( + dictsch, + (col=de.data,), + id, + false, + alignment, + ); + alignment=alignment, + ), + ) + end + end + elseif !isempty(cols.dictencodingdeltas) + for de in cols.dictencodingdeltas + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + emit( + _flightdata_message( + ArrowParent.makedictionarybatchmsg( + dictsch, + (col=de.data,), + de.id, + true, + alignment, + ); + alignment=alignment, + ), + ) + end + end + emit( + _flightdata_message( + ArrowParent.makerecordbatchmsg(schema[], cols, alignment); + app_metadata=record_app_metadata, + alignment=alignment, + ), + ) + descriptor = nothing + end + _ensure_app_metadata_consumed(app_metadata_cursor) + return nothing +end + +function flightdata( + source; + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, +) + messages = Protocol.FlightData[] + _emitflightdata!( + message -> push!(messages, message), + source; + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + app_metadata=app_metadata, + ) + return messages +end + +function putflightdata!( + sink, + source; + close::Bool=false, + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, +) + try + _emitflightdata!( + message -> put!(sink, message), + source; + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + app_metadata=app_metadata, + ) + finally + close && Base.close(sink) + end + return sink +end diff --git a/src/flight/convert/framing.jl b/src/flight/convert/framing.jl new file mode 100644 index 0000000..555219d --- /dev/null +++ b/src/flight/convert/framing.jl @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _message_body(msg::ArrowParent.Message, alignment::Integer) + msg.columns === nothing && return UInt8[] + io = IOBuffer() + for col in Tables.Columns(msg.columns) + ArrowParent.writebuffer(io, col, alignment) + end + return take!(io) +end + +function _flightdata_message( + msg::ArrowParent.Message; + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + app_metadata::AbstractVector{UInt8}=UInt8[], + alignment::Integer=DEFAULT_IPC_ALIGNMENT, +) + body = _message_body(msg, alignment) + length(body) == msg.bodylen || + throw(ArgumentError("FlightData body length mismatch while encoding Arrow IPC")) + return Protocol.FlightData( + descriptor, + Vector{UInt8}(msg.msgflatbuf), + Vector{UInt8}(app_metadata), + body, + ) +end + +function _write_framed_message( + io::IO, + data_header::AbstractVector{UInt8}, + data_body::AbstractVector{UInt8}, + alignment::Integer, +) + metalen = ArrowParent.padding(length(data_header), alignment) + Base.write(io, ArrowParent.CONTINUATION_INDICATOR_BYTES) + Base.write(io, Int32(metalen)) + Base.write(io, data_header) + ArrowParent.writezeros(io, ArrowParent.paddinglength(length(data_header), alignment)) + Base.write(io, data_body) + return +end + +function _write_end_marker(io::IO) + Base.write(io, ArrowParent.CONTINUATION_INDICATOR_BYTES) + Base.write(io, Int32(0)) + return +end diff --git a/src/flight/convert/schema.jl b/src/flight/convert/schema.jl new file mode 100644 index 0000000..55cf33d --- /dev/null +++ b/src/flight/convert/schema.jl @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _normalize_schemaipc( + schema::AbstractVector{UInt8}; + alignment::Integer=DEFAULT_IPC_ALIGNMENT, +) + bytes = Vector{UInt8}(schema) + isempty(bytes) && throw(ArgumentError("schema bytes cannot be empty")) + if length(bytes) >= 8 && + ArrowParent.readbuffer(bytes, 1, UInt32) == ArrowParent.CONTINUATION_INDICATOR_BYTES + return bytes + end + if length(bytes) >= 4 + metalen = ArrowParent.readbuffer(bytes, 1, Int32) + if metalen >= 0 && metalen == length(bytes) - 4 + io = IOBuffer() + Base.write(io, ArrowParent.CONTINUATION_INDICATOR_BYTES) + Base.write(io, bytes) + return take!(io) + end + end + io = IOBuffer() + _write_framed_message(io, bytes, UInt8[], alignment) + return take!(io) +end + +schemaipc(result::Protocol.SchemaResult; alignment::Integer=DEFAULT_IPC_ALIGNMENT) = + _normalize_schemaipc(result.schema; alignment=alignment) + +schemaipc(info::Protocol.FlightInfo; alignment::Integer=DEFAULT_IPC_ALIGNMENT) = + _normalize_schemaipc(info.schema; alignment=alignment) + +schemaipc(schema::AbstractVector{UInt8}; alignment::Integer=DEFAULT_IPC_ALIGNMENT) = + _normalize_schemaipc(schema; alignment=alignment) + +function schemaipc(message::Protocol.FlightData; alignment::Integer=DEFAULT_IPC_ALIGNMENT) + isempty(message.data_header) && + throw(ArgumentError("FlightData message is missing the Arrow IPC header")) + io = IOBuffer() + _write_framed_message(io, message.data_header, message.data_body, alignment) + return take!(io) +end + +function schemaipc(source; kwargs...) + alignment = get(kwargs, :alignment, DEFAULT_IPC_ALIGNMENT) + messages = flightdata(source; kwargs...) + isempty(messages) && + throw(ArgumentError("cannot derive schema bytes from an empty Flight source")) + return schemaipc(first(messages); alignment=alignment) +end diff --git a/src/flight/convert/streaming.jl b/src/flight/convert/streaming.jl new file mode 100644 index 0000000..06b2274 --- /dev/null +++ b/src/flight/convert/streaming.jl @@ -0,0 +1,457 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +streambytes(message::Protocol.FlightData; kwargs...) = + streambytes(Protocol.FlightData[message]; kwargs...) + +mutable struct FlightStream{M} + messages::M + state::Any + started::Bool + exhausted::Bool + nextid::Int + names::Vector{Symbol} + types::Vector{Type} + schema::Union{Nothing,ArrowParent.Meta.Schema} + dictencodings::ArrowParent.Lockable{Dict{Int64,ArrowParent.DictEncoding}} + dictencoded::Dict{Int64,ArrowParent.Meta.Field} + convert::Bool +end + +struct FlightStreamWithAppMetadata{S} + stream::S +end + +function FlightStream(messages; schema=nothing, convert::Bool=true) + x = FlightStream( + messages, + nothing, + false, + false, + 0, + Symbol[], + Type[], + nothing, + ArrowParent.Lockable(Dict{Int64,ArrowParent.DictEncoding}()), + Dict{Int64,ArrowParent.Meta.Field}(), + convert, + ) + schema === nothing || _register_schema!(x, _flight_schema(schema)) + return x +end + +Base.IteratorSize(::Type{<:FlightStream}) = Base.SizeUnknown() +Base.eltype(::Type{<:FlightStream}) = ArrowParent.Table +Base.isdone(x::FlightStream) = x.exhausted + +Base.IteratorSize(::Type{<:FlightStreamWithAppMetadata}) = Base.SizeUnknown() +Base.eltype(::Type{<:FlightStreamWithAppMetadata}) = + NamedTuple{(:table, :app_metadata),Tuple{ArrowParent.Table,Vector{UInt8}}} + +Tables.partitions(x::FlightStream) = x +Tables.partitions(x::FlightStreamWithAppMetadata) = x + +function Tables.columnnames(x::FlightStream) + _ensure_schema!(x) + return getfield(x, :names) +end + +function Tables.schema(x::FlightStream) + _ensure_schema!(x) + return Tables.Schema(Tables.columnnames(x), getfield(x, :types)) +end + +function Base.iterate(x::FlightStream) + return _iterate_flight_stream!(x) +end + +function Base.iterate(x::FlightStream, ::Nothing) + return _iterate_flight_stream!(x) +end + +function Base.iterate(x::FlightStreamWithAppMetadata) + return _iterate_flight_stream!(x.stream; include_app_metadata=true) +end + +function Base.iterate(x::FlightStreamWithAppMetadata, ::Nothing) + return _iterate_flight_stream!(x.stream; include_app_metadata=true) +end + +function _missing_schema_message() + return join( + [ + "cannot derive Arrow Flight schema from a response stream without a schema message", + "the server may have terminated the stream before emitting the first schema-bearing FlightData message", + "or the underlying transport did not surface the corresponding gRPC status", + ], + "; ", + ) +end + +function _require_schema_messages(messages::AbstractVector{<:Protocol.FlightData}, schema) + schema === nothing || return messages + any(message -> !isempty(message.data_header), messages) && return messages + throw(ArgumentError(_missing_schema_message())) +end + +function _flight_schema(schema) + schema isa ArrowParent.Meta.Schema && return schema + bytes = schemaipc(schema) + message = ArrowParent.FlatBuffers.getrootas(ArrowParent.Meta.Message, bytes, 8) + header = message.header + header isa ArrowParent.Meta.Schema || + throw(ArgumentError("Flight schema payload did not decode to an Arrow IPC schema")) + return header +end + +function _register_schema!(x::FlightStream, schema::ArrowParent.Meta.Schema) + if isnothing(getfield(x, :schema)) + setfield!(x, :schema, schema) + for field in schema.fields + ArrowParent.rejectunsupported(field) + push!(getfield(x, :names), Symbol(field.name)) + push!( + getfield(x, :types), + ArrowParent.juliaeltype( + field, + ArrowParent.buildmetadata(field.custom_metadata), + getfield(x, :convert), + ), + ) + ArrowParent.getdictionaries!(getfield(x, :dictencoded), field) + end + return x + end + schema == getfield(x, :schema) || throw( + ArgumentError( + "mismatched schemas between different arrow batches: $(getfield(x, :schema)) != $schema", + ), + ) + return x +end + +function _next_flight_message!(x::FlightStream) + getfield(x, :exhausted) && return nothing + state = + getfield(x, :started) ? iterate(getfield(x, :messages), getfield(x, :state)) : + iterate(getfield(x, :messages)) + setfield!(x, :started, true) + state === nothing && return (setfield!(x, :exhausted, true); nothing) + message, next_state = state + setfield!(x, :state, next_state) + setfield!(x, :nextid, getfield(x, :nextid) + 1) + return message +end + +function _flight_batch(message::Protocol.FlightData, id::Integer) + isempty(message.data_header) && + throw(ArgumentError("FlightData message is missing the Arrow IPC header")) + msg = + ArrowParent.FlatBuffers.getrootas(ArrowParent.Meta.Message, message.data_header, 0) + return ArrowParent.Batch(msg, message.data_body, 1, Int(id)) +end + +function _ensure_schema!(x::FlightStream) + isnothing(getfield(x, :schema)) || return x + while true + message = _next_flight_message!(x) + message === nothing && throw(ArgumentError(_missing_schema_message())) + if isempty(message.data_header) + isempty(message.data_body) || throw( + ArgumentError("FlightData message has a body but no Arrow IPC header"), + ) + continue + end + batch = _flight_batch(message, getfield(x, :nextid)) + header = batch.msg.header + if header isa ArrowParent.Meta.Schema + _register_schema!(x, header) + return x + elseif header isa ArrowParent.Meta.Tensor + throw(ArgumentError(ArrowParent.TENSOR_UNSUPPORTED)) + elseif header isa ArrowParent.Meta.SparseTensor + throw(ArgumentError(ArrowParent.SPARSE_TENSOR_UNSUPPORTED)) + end + throw(ArgumentError(_missing_schema_message())) + end +end + +function _store_dictionary_batch!( + x::FlightStream, + batch, + header::ArrowParent.Meta.DictionaryBatch, +) + id = header.id + recordbatch = header.data + @lock getfield(x, :dictencodings) begin + dictencodings = getfield(x, :dictencodings)[] + if haskey(dictencodings, id) && header.isDelta + field = getfield(x, :dictencoded)[id] + values, _, _, _ = ArrowParent.build( + field, + field.type, + batch, + recordbatch, + getfield(x, :dictencodings), + Int64(1), + Int64(1), + Int64(1), + getfield(x, :convert), + ) + dictencoding = dictencodings[id] + append!(dictencoding.data, values) + return + end + field = getfield(x, :dictencoded)[id] + values, _, _, _ = ArrowParent.build( + field, + field.type, + batch, + recordbatch, + getfield(x, :dictencodings), + Int64(1), + Int64(1), + Int64(1), + getfield(x, :convert), + ) + A = ArrowParent.ChainedVector([values]) + S = + field.dictionary.indexType === nothing ? Int32 : + ArrowParent.juliaeltype(field, field.dictionary.indexType, false) + dictencodings[id] = ArrowParent.DictEncoding{eltype(A),S,typeof(A)}( + id, + A, + field.dictionary.isOrdered, + values.metadata, + ) + end + return nothing +end + +function _flight_table(x::FlightStream, columns) + schema = getfield(x, :schema) + schema === nothing && throw(ArgumentError(_missing_schema_message())) + lookup = Dict{Symbol,AbstractVector}() + types = Type[] + for (nm, col) in zip(getfield(x, :names), columns) + lookup[nm] = col + push!(types, eltype(col)) + end + return ArrowParent.Table(getfield(x, :names), types, columns, lookup, Ref(schema)) +end + +function _empty_flight_table(x::FlightStream) + schema = getfield(x, :schema) + schema === nothing && throw(ArgumentError(_missing_schema_message())) + names = copy(getfield(x, :names)) + types = copy(getfield(x, :types)) + columns = AbstractVector[] + for field in schema.fields + T = ArrowParent.juliaeltype( + field, + ArrowParent.buildmetadata(field.custom_metadata), + getfield(x, :convert), + ) + push!(columns, T[]) + end + lookup = Dict{Symbol,AbstractVector}(names[i] => columns[i] for i in eachindex(names)) + return ArrowParent.Table(names, types, columns, lookup, Ref(schema)) +end + +function _copy_flight_table(batch::ArrowParent.Table) + names = copy(ArrowParent.names(batch)) + types = copy(ArrowParent.types(batch)) + columns = copy(ArrowParent.columns(batch)) + schema = ArrowParent.schema(batch)[] + lookup = Dict{Symbol,AbstractVector}(names[i] => columns[i] for i in eachindex(names)) + return ArrowParent.Table(names, types, columns, lookup, Ref(schema)) +end + +_copy_app_metadata(message::Protocol.FlightData) = copy(message.app_metadata) + +function _flight_batch_result( + table::ArrowParent.Table, + message::Protocol.FlightData; + include_app_metadata::Bool, +) + include_app_metadata || return table + return (table=table, app_metadata=_copy_app_metadata(message)) +end + +_flightcolumndata(col) = ArrowParent._metadatavectordata(col) + +function _chain_flight_column(col, batch_col) + metadata = ArrowParent.getmetadata(col) + chained = + ArrowParent.ChainedVector([_flightcolumndata(col), _flightcolumndata(batch_col)]) + return ArrowParent._wrapmetadata(chained, metadata) +end + +function _append_flight_column!(col, batch_col) + append!(_flightcolumndata(col), _flightcolumndata(batch_col)) + return col +end + +function _append_flight_batch!( + table::ArrowParent.Table, + batch::ArrowParent.Table, + batchindex::Int, +) + columns = ArrowParent.columns(table) + batch_columns = ArrowParent.columns(batch) + if batchindex == 2 + for i in eachindex(columns) + columns[i] = _chain_flight_column(columns[i], batch_columns[i]) + end + else + for i in eachindex(columns) + _append_flight_column!(columns[i], batch_columns[i]) + end + end + lookup = getfield(table, :lookup) + for (nm, col) in zip(ArrowParent.names(table), columns) + lookup[nm] = col + end + return table +end + +function _materialize_flight_table( + messages; + schema=nothing, + convert::Bool=true, + include_app_metadata::Bool=false, +) + stream_state = FlightStream(messages; schema=schema, convert=convert) + state = _iterate_flight_stream!(stream_state; include_app_metadata=include_app_metadata) + if state === nothing + empty_table = _empty_flight_table(stream_state) + return include_app_metadata ? (table=empty_table, app_metadata=Vector{UInt8}[]) : + empty_table + end + first_value, _ = state + first_table = include_app_metadata ? first_value.table : first_value + out = _copy_flight_table(first_table) + batch_app_metadata = + include_app_metadata ? Vector{Vector{UInt8}}([first_value.app_metadata]) : nothing + batchindex = 2 + while true + next = + _iterate_flight_stream!(stream_state; include_app_metadata=include_app_metadata) + next === nothing && break + batch_value, _ = next + batch = include_app_metadata ? batch_value.table : batch_value + _append_flight_batch!(out, batch, batchindex) + include_app_metadata && push!(batch_app_metadata, batch_value.app_metadata) + batchindex += 1 + end + return include_app_metadata ? (table=out, app_metadata=batch_app_metadata) : out +end + +function _iterate_flight_stream!(x::FlightStream; include_app_metadata::Bool=false) + _ensure_schema!(x) + while true + message = _next_flight_message!(x) + message === nothing && return nothing + if isempty(message.data_header) + isempty(message.data_body) || throw( + ArgumentError("FlightData message has a body but no Arrow IPC header"), + ) + continue + end + batch = _flight_batch(message, getfield(x, :nextid)) + header = batch.msg.header + if header isa ArrowParent.Meta.Schema + _register_schema!(x, header) + continue + elseif header isa ArrowParent.Meta.DictionaryBatch + _store_dictionary_batch!(x, batch, header) + continue + elseif header isa ArrowParent.Meta.RecordBatch + columns = collect( + ArrowParent.VectorIterator( + getfield(x, :schema), + batch, + getfield(x, :dictencodings), + getfield(x, :convert), + ), + ) + return _flight_batch_result( + _flight_table(x, columns), + message; + include_app_metadata=include_app_metadata, + ), + nothing + elseif header isa ArrowParent.Meta.Tensor + throw(ArgumentError(ArrowParent.TENSOR_UNSUPPORTED)) + elseif header isa ArrowParent.Meta.SparseTensor + throw(ArgumentError(ArrowParent.SPARSE_TENSOR_UNSUPPORTED)) + end + throw(ArgumentError("unsupported arrow message type: $(typeof(header))")) + end +end + +function streambytes( + messages; + schema=nothing, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + end_marker::Bool=true, +) + collected = _require_schema_messages(_collect_messages(messages), schema) + io = IOBuffer() + schema === nothing || Base.write(io, schemaipc(schema; alignment=alignment)) + for message in collected + if isempty(message.data_header) + isempty(message.data_body) || throw( + ArgumentError("FlightData message has a body but no Arrow IPC header"), + ) + continue + end + _write_framed_message(io, message.data_header, message.data_body, alignment) + end + end_marker && _write_end_marker(io) + return take!(io) +end + +function stream( + messages; + schema=nothing, + convert::Bool=true, + include_app_metadata::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + end_marker::Bool=true, +) + messages isa AbstractVector{<:Protocol.FlightData} && + _require_schema_messages(messages, schema) + flight_stream = FlightStream(messages; schema=schema, convert=convert) + return include_app_metadata ? FlightStreamWithAppMetadata(flight_stream) : flight_stream +end + +function table( + messages; + schema=nothing, + convert::Bool=true, + include_app_metadata::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + end_marker::Bool=true, +) + return _materialize_flight_table( + messages; + schema=schema, + convert=convert, + include_app_metadata=include_app_metadata, + ) +end diff --git a/src/flight/exports.jl b/src/flight/exports.jl new file mode 100644 index 0000000..89dc88f --- /dev/null +++ b/src/flight/exports.jl @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +export Client, + Service, + ServerCallContext, + MethodDescriptor, + ServiceDescriptor, + withheaders, + withtoken, + Protocol, + Generated, + authenticate, + callheader, + servicedescriptor, + lookupmethod, + dispatch, + handshake, + listflights, + getflightinfo, + pollflightinfo, + getschema, + doget, + doput, + doexchange, + doaction, + listactions, + schemaipc, + streambytes, + stream, + table, + withappmetadata, + flightdata, + putflightdata! diff --git a/src/flight/generated/arrow/arrow.jl b/src/flight/generated/arrow/arrow.jl new file mode 100644 index 0000000..1b821af --- /dev/null +++ b/src/flight/generated/arrow/arrow.jl @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module arrow + +include("../google/google.jl") + +include("flight/flight.jl") + +end # module arrow diff --git a/src/flight/generated/arrow/flight/flight.jl b/src/flight/generated/arrow/flight/flight.jl new file mode 100644 index 0000000..9a6b34b --- /dev/null +++ b/src/flight/generated/arrow/flight/flight.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module flight + +include("protocol/protocol.jl") + +end # module flight diff --git a/src/flight/generated/arrow/flight/protocol/Flight_pb.jl b/src/flight/generated/arrow/flight/protocol/Flight_pb.jl new file mode 100644 index 0000000..a7c73d3 --- /dev/null +++ b/src/flight/generated/arrow/flight/protocol/Flight_pb.jl @@ -0,0 +1,1359 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Autogenerated using ProtoBuf.jl v1.2.3 +# original file: Flight.proto (proto3 syntax) + +import ProtoBuf as PB +import gRPCClient +using ProtoBuf: OneOf +using ProtoBuf.EnumX: @enumx + +export HandshakeRequest, Ticket, HandshakeResponse, Action +export var"FlightDescriptor.DescriptorType", Criteria, CloseSessionRequest, Result +export ActionType, PutResult, Empty, var"SessionOptionValue.StringListValue", SchemaResult +export CancelStatus, GetSessionOptionsRequest, var"SetSessionOptionsResult.ErrorValue" +export Location, var"CloseSessionResult.Status", BasicAuth, FlightDescriptor +export SessionOptionValue, CancelFlightInfoResult, var"SetSessionOptionsResult.Error" +export FlightEndpoint, CloseSessionResult, FlightData, SetSessionOptionsRequest +export GetSessionOptionsResult, SetSessionOptionsResult, RenewFlightEndpointRequest +export FlightInfo, CancelFlightInfoRequest, PollInfo + +struct HandshakeRequest + protocol_version::UInt64 + payload::Vector{UInt8} +end +PB.default_values(::Type{HandshakeRequest}) = + (; protocol_version=zero(UInt64), payload=UInt8[]) +PB.field_numbers(::Type{HandshakeRequest}) = (; protocol_version=1, payload=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:HandshakeRequest}) + protocol_version = zero(UInt64) + payload = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + protocol_version = PB.decode(d, UInt64) + elseif field_number == 2 + payload = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return HandshakeRequest(protocol_version, payload) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::HandshakeRequest) + initpos = position(e.io) + x.protocol_version != zero(UInt64) && PB.encode(e, 1, x.protocol_version) + !isempty(x.payload) && PB.encode(e, 2, x.payload) + return position(e.io) - initpos +end +function PB._encoded_size(x::HandshakeRequest) + encoded_size = 0 + x.protocol_version != zero(UInt64) && + (encoded_size += PB._encoded_size(x.protocol_version, 1)) + !isempty(x.payload) && (encoded_size += PB._encoded_size(x.payload, 2)) + return encoded_size +end + +struct Ticket + ticket::Vector{UInt8} +end +PB.default_values(::Type{Ticket}) = (; ticket=UInt8[]) +PB.field_numbers(::Type{Ticket}) = (; ticket=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Ticket}) + ticket = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + ticket = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Ticket(ticket) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Ticket) + initpos = position(e.io) + !isempty(x.ticket) && PB.encode(e, 1, x.ticket) + return position(e.io) - initpos +end +function PB._encoded_size(x::Ticket) + encoded_size = 0 + !isempty(x.ticket) && (encoded_size += PB._encoded_size(x.ticket, 1)) + return encoded_size +end + +struct HandshakeResponse + protocol_version::UInt64 + payload::Vector{UInt8} +end +PB.default_values(::Type{HandshakeResponse}) = + (; protocol_version=zero(UInt64), payload=UInt8[]) +PB.field_numbers(::Type{HandshakeResponse}) = (; protocol_version=1, payload=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:HandshakeResponse}) + protocol_version = zero(UInt64) + payload = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + protocol_version = PB.decode(d, UInt64) + elseif field_number == 2 + payload = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return HandshakeResponse(protocol_version, payload) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::HandshakeResponse) + initpos = position(e.io) + x.protocol_version != zero(UInt64) && PB.encode(e, 1, x.protocol_version) + !isempty(x.payload) && PB.encode(e, 2, x.payload) + return position(e.io) - initpos +end +function PB._encoded_size(x::HandshakeResponse) + encoded_size = 0 + x.protocol_version != zero(UInt64) && + (encoded_size += PB._encoded_size(x.protocol_version, 1)) + !isempty(x.payload) && (encoded_size += PB._encoded_size(x.payload, 2)) + return encoded_size +end + +struct Action + var"#type"::String + body::Vector{UInt8} +end +PB.default_values(::Type{Action}) = (; var"#type"="", body=UInt8[]) +PB.field_numbers(::Type{Action}) = (; var"#type"=1, body=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Action}) + var"#type" = "" + body = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + var"#type" = PB.decode(d, String) + elseif field_number == 2 + body = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Action(var"#type", body) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Action) + initpos = position(e.io) + !isempty(x.var"#type") && PB.encode(e, 1, x.var"#type") + !isempty(x.body) && PB.encode(e, 2, x.body) + return position(e.io) - initpos +end +function PB._encoded_size(x::Action) + encoded_size = 0 + !isempty(x.var"#type") && (encoded_size += PB._encoded_size(x.var"#type", 1)) + !isempty(x.body) && (encoded_size += PB._encoded_size(x.body, 2)) + return encoded_size +end + +@enumx var"FlightDescriptor.DescriptorType" UNKNOWN=0 PATH=1 CMD=2 + +struct Criteria + expression::Vector{UInt8} +end +PB.default_values(::Type{Criteria}) = (; expression=UInt8[]) +PB.field_numbers(::Type{Criteria}) = (; expression=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Criteria}) + expression = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + expression = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Criteria(expression) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Criteria) + initpos = position(e.io) + !isempty(x.expression) && PB.encode(e, 1, x.expression) + return position(e.io) - initpos +end +function PB._encoded_size(x::Criteria) + encoded_size = 0 + !isempty(x.expression) && (encoded_size += PB._encoded_size(x.expression, 1)) + return encoded_size +end + +struct CloseSessionRequest end + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:CloseSessionRequest}) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + Base.skip(d, wire_type) + end + return CloseSessionRequest() +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CloseSessionRequest) + initpos = position(e.io) + return position(e.io) - initpos +end +function PB._encoded_size(x::CloseSessionRequest) + encoded_size = 0 + return encoded_size +end + +struct Result + body::Vector{UInt8} +end +PB.default_values(::Type{Result}) = (; body=UInt8[]) +PB.field_numbers(::Type{Result}) = (; body=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Result}) + body = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + body = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Result(body) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Result) + initpos = position(e.io) + !isempty(x.body) && PB.encode(e, 1, x.body) + return position(e.io) - initpos +end +function PB._encoded_size(x::Result) + encoded_size = 0 + !isempty(x.body) && (encoded_size += PB._encoded_size(x.body, 1)) + return encoded_size +end + +struct ActionType + var"#type"::String + description::String +end +PB.default_values(::Type{ActionType}) = (; var"#type"="", description="") +PB.field_numbers(::Type{ActionType}) = (; var"#type"=1, description=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:ActionType}) + var"#type" = "" + description = "" + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + var"#type" = PB.decode(d, String) + elseif field_number == 2 + description = PB.decode(d, String) + else + Base.skip(d, wire_type) + end + end + return ActionType(var"#type", description) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::ActionType) + initpos = position(e.io) + !isempty(x.var"#type") && PB.encode(e, 1, x.var"#type") + !isempty(x.description) && PB.encode(e, 2, x.description) + return position(e.io) - initpos +end +function PB._encoded_size(x::ActionType) + encoded_size = 0 + !isempty(x.var"#type") && (encoded_size += PB._encoded_size(x.var"#type", 1)) + !isempty(x.description) && (encoded_size += PB._encoded_size(x.description, 2)) + return encoded_size +end + +struct PutResult + app_metadata::Vector{UInt8} +end +PB.default_values(::Type{PutResult}) = (; app_metadata=UInt8[]) +PB.field_numbers(::Type{PutResult}) = (; app_metadata=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:PutResult}) + app_metadata = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + app_metadata = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return PutResult(app_metadata) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::PutResult) + initpos = position(e.io) + !isempty(x.app_metadata) && PB.encode(e, 1, x.app_metadata) + return position(e.io) - initpos +end +function PB._encoded_size(x::PutResult) + encoded_size = 0 + !isempty(x.app_metadata) && (encoded_size += PB._encoded_size(x.app_metadata, 1)) + return encoded_size +end + +struct Empty end + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Empty}) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + Base.skip(d, wire_type) + end + return Empty() +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Empty) + initpos = position(e.io) + return position(e.io) - initpos +end +function PB._encoded_size(x::Empty) + encoded_size = 0 + return encoded_size +end + +struct var"SessionOptionValue.StringListValue" + values::Vector{String} +end +PB.default_values(::Type{var"SessionOptionValue.StringListValue"}) = + (; values=Vector{String}()) +PB.field_numbers(::Type{var"SessionOptionValue.StringListValue"}) = (; values=1) + +function PB.decode( + d::PB.AbstractProtoDecoder, + ::Type{<:var"SessionOptionValue.StringListValue"}, +) + values = PB.BufferedVector{String}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, values) + else + Base.skip(d, wire_type) + end + end + return var"SessionOptionValue.StringListValue"(values[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::var"SessionOptionValue.StringListValue") + initpos = position(e.io) + !isempty(x.values) && PB.encode(e, 1, x.values) + return position(e.io) - initpos +end +function PB._encoded_size(x::var"SessionOptionValue.StringListValue") + encoded_size = 0 + !isempty(x.values) && (encoded_size += PB._encoded_size(x.values, 1)) + return encoded_size +end + +struct SchemaResult + schema::Vector{UInt8} +end +PB.default_values(::Type{SchemaResult}) = (; schema=UInt8[]) +PB.field_numbers(::Type{SchemaResult}) = (; schema=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SchemaResult}) + schema = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + schema = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return SchemaResult(schema) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SchemaResult) + initpos = position(e.io) + !isempty(x.schema) && PB.encode(e, 1, x.schema) + return position(e.io) - initpos +end +function PB._encoded_size(x::SchemaResult) + encoded_size = 0 + !isempty(x.schema) && (encoded_size += PB._encoded_size(x.schema, 1)) + return encoded_size +end + +@enumx CancelStatus CANCEL_STATUS_UNSPECIFIED=0 CANCEL_STATUS_CANCELLED=1 CANCEL_STATUS_CANCELLING=2 CANCEL_STATUS_NOT_CANCELLABLE=3 + +struct GetSessionOptionsRequest end + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:GetSessionOptionsRequest}) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + Base.skip(d, wire_type) + end + return GetSessionOptionsRequest() +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::GetSessionOptionsRequest) + initpos = position(e.io) + return position(e.io) - initpos +end +function PB._encoded_size(x::GetSessionOptionsRequest) + encoded_size = 0 + return encoded_size +end + +@enumx var"SetSessionOptionsResult.ErrorValue" UNSPECIFIED=0 INVALID_NAME=1 INVALID_VALUE=2 ERROR=3 + +struct Location + uri::String +end +PB.default_values(::Type{Location}) = (; uri="") +PB.field_numbers(::Type{Location}) = (; uri=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Location}) + uri = "" + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + uri = PB.decode(d, String) + else + Base.skip(d, wire_type) + end + end + return Location(uri) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Location) + initpos = position(e.io) + !isempty(x.uri) && PB.encode(e, 1, x.uri) + return position(e.io) - initpos +end +function PB._encoded_size(x::Location) + encoded_size = 0 + !isempty(x.uri) && (encoded_size += PB._encoded_size(x.uri, 1)) + return encoded_size +end + +@enumx var"CloseSessionResult.Status" UNSPECIFIED=0 CLOSED=1 CLOSING=2 NOT_CLOSEABLE=3 + +struct BasicAuth + username::String + password::String +end +PB.default_values(::Type{BasicAuth}) = (; username="", password="") +PB.field_numbers(::Type{BasicAuth}) = (; username=2, password=3) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:BasicAuth}) + username = "" + password = "" + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 2 + username = PB.decode(d, String) + elseif field_number == 3 + password = PB.decode(d, String) + else + Base.skip(d, wire_type) + end + end + return BasicAuth(username, password) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::BasicAuth) + initpos = position(e.io) + !isempty(x.username) && PB.encode(e, 2, x.username) + !isempty(x.password) && PB.encode(e, 3, x.password) + return position(e.io) - initpos +end +function PB._encoded_size(x::BasicAuth) + encoded_size = 0 + !isempty(x.username) && (encoded_size += PB._encoded_size(x.username, 2)) + !isempty(x.password) && (encoded_size += PB._encoded_size(x.password, 3)) + return encoded_size +end + +struct FlightDescriptor + var"#type"::var"FlightDescriptor.DescriptorType".T + cmd::Vector{UInt8} + path::Vector{String} +end +PB.default_values(::Type{FlightDescriptor}) = (; + var"#type"=var"FlightDescriptor.DescriptorType".UNKNOWN, + cmd=UInt8[], + path=Vector{String}(), +) +PB.field_numbers(::Type{FlightDescriptor}) = (; var"#type"=1, cmd=2, path=3) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightDescriptor}) + var"#type" = var"FlightDescriptor.DescriptorType".UNKNOWN + cmd = UInt8[] + path = PB.BufferedVector{String}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + var"#type" = PB.decode(d, var"FlightDescriptor.DescriptorType".T) + elseif field_number == 2 + cmd = PB.decode(d, Vector{UInt8}) + elseif field_number == 3 + PB.decode!(d, path) + else + Base.skip(d, wire_type) + end + end + return FlightDescriptor(var"#type", cmd, path[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightDescriptor) + initpos = position(e.io) + x.var"#type" != var"FlightDescriptor.DescriptorType".UNKNOWN && + PB.encode(e, 1, x.var"#type") + !isempty(x.cmd) && PB.encode(e, 2, x.cmd) + !isempty(x.path) && PB.encode(e, 3, x.path) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightDescriptor) + encoded_size = 0 + x.var"#type" != var"FlightDescriptor.DescriptorType".UNKNOWN && + (encoded_size += PB._encoded_size(x.var"#type", 1)) + !isempty(x.cmd) && (encoded_size += PB._encoded_size(x.cmd, 2)) + !isempty(x.path) && (encoded_size += PB._encoded_size(x.path, 3)) + return encoded_size +end + +struct SessionOptionValue + option_value::Union{ + Nothing, + OneOf{<:Union{String,Bool,Int64,Float64,var"SessionOptionValue.StringListValue"}}, + } +end +PB.oneof_field_types(::Type{SessionOptionValue}) = (; + option_value=(; + string_value=String, + bool_value=Bool, + int64_value=Int64, + double_value=Float64, + string_list_value=var"SessionOptionValue.StringListValue", + ), +) +PB.default_values(::Type{SessionOptionValue}) = (; + string_value="", + bool_value=false, + int64_value=zero(Int64), + double_value=zero(Float64), + string_list_value=nothing, +) +PB.field_numbers(::Type{SessionOptionValue}) = + (; string_value=1, bool_value=2, int64_value=3, double_value=4, string_list_value=5) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SessionOptionValue}) + option_value = nothing + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + option_value = OneOf(:string_value, PB.decode(d, String)) + elseif field_number == 2 + option_value = OneOf(:bool_value, PB.decode(d, Bool)) + elseif field_number == 3 + option_value = OneOf(:int64_value, PB.decode(d, Int64, Val{:fixed})) + elseif field_number == 4 + option_value = OneOf(:double_value, PB.decode(d, Float64)) + elseif field_number == 5 + option_value = OneOf( + :string_list_value, + PB.decode(d, Ref{var"SessionOptionValue.StringListValue"}), + ) + else + Base.skip(d, wire_type) + end + end + return SessionOptionValue(option_value) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SessionOptionValue) + initpos = position(e.io) + if isnothing(x.option_value) + ; + elseif x.option_value.name === :string_value + PB.encode(e, 1, x.option_value[]::String) + elseif x.option_value.name === :bool_value + PB.encode(e, 2, x.option_value[]::Bool) + elseif x.option_value.name === :int64_value + PB.encode(e, 3, x.option_value[]::Int64, Val{:fixed}) + elseif x.option_value.name === :double_value + PB.encode(e, 4, x.option_value[]::Float64) + elseif x.option_value.name === :string_list_value + PB.encode(e, 5, x.option_value[]::var"SessionOptionValue.StringListValue") + end + return position(e.io) - initpos +end +function PB._encoded_size(x::SessionOptionValue) + encoded_size = 0 + if isnothing(x.option_value) + ; + elseif x.option_value.name === :string_value + encoded_size += PB._encoded_size(x.option_value[]::String, 1) + elseif x.option_value.name === :bool_value + encoded_size += PB._encoded_size(x.option_value[]::Bool, 2) + elseif x.option_value.name === :int64_value + encoded_size += PB._encoded_size(x.option_value[]::Int64, 3, Val{:fixed}) + elseif x.option_value.name === :double_value + encoded_size += PB._encoded_size(x.option_value[]::Float64, 4) + elseif x.option_value.name === :string_list_value + encoded_size += + PB._encoded_size(x.option_value[]::var"SessionOptionValue.StringListValue", 5) + end + return encoded_size +end + +struct CancelFlightInfoResult + status::CancelStatus.T +end +PB.default_values(::Type{CancelFlightInfoResult}) = + (; status=CancelStatus.CANCEL_STATUS_UNSPECIFIED) +PB.field_numbers(::Type{CancelFlightInfoResult}) = (; status=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:CancelFlightInfoResult}) + status = CancelStatus.CANCEL_STATUS_UNSPECIFIED + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + status = PB.decode(d, CancelStatus.T) + else + Base.skip(d, wire_type) + end + end + return CancelFlightInfoResult(status) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CancelFlightInfoResult) + initpos = position(e.io) + x.status != CancelStatus.CANCEL_STATUS_UNSPECIFIED && PB.encode(e, 1, x.status) + return position(e.io) - initpos +end +function PB._encoded_size(x::CancelFlightInfoResult) + encoded_size = 0 + x.status != CancelStatus.CANCEL_STATUS_UNSPECIFIED && + (encoded_size += PB._encoded_size(x.status, 1)) + return encoded_size +end + +struct var"SetSessionOptionsResult.Error" + value::var"SetSessionOptionsResult.ErrorValue".T +end +PB.default_values(::Type{var"SetSessionOptionsResult.Error"}) = + (; value=var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED) +PB.field_numbers(::Type{var"SetSessionOptionsResult.Error"}) = (; value=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:var"SetSessionOptionsResult.Error"}) + value = var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + value = PB.decode(d, var"SetSessionOptionsResult.ErrorValue".T) + else + Base.skip(d, wire_type) + end + end + return var"SetSessionOptionsResult.Error"(value) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::var"SetSessionOptionsResult.Error") + initpos = position(e.io) + x.value != var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED && + PB.encode(e, 1, x.value) + return position(e.io) - initpos +end +function PB._encoded_size(x::var"SetSessionOptionsResult.Error") + encoded_size = 0 + x.value != var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED && + (encoded_size += PB._encoded_size(x.value, 1)) + return encoded_size +end + +struct FlightEndpoint + ticket::Union{Nothing,Ticket} + location::Vector{Location} + expiration_time::Union{Nothing,google.protobuf.Timestamp} + app_metadata::Vector{UInt8} +end +PB.default_values(::Type{FlightEndpoint}) = (; + ticket=nothing, + location=Vector{Location}(), + expiration_time=nothing, + app_metadata=UInt8[], +) +PB.field_numbers(::Type{FlightEndpoint}) = + (; ticket=1, location=2, expiration_time=3, app_metadata=4) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightEndpoint}) + ticket = Ref{Union{Nothing,Ticket}}(nothing) + location = PB.BufferedVector{Location}() + expiration_time = Ref{Union{Nothing,google.protobuf.Timestamp}}(nothing) + app_metadata = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, ticket) + elseif field_number == 2 + PB.decode!(d, location) + elseif field_number == 3 + PB.decode!(d, expiration_time) + elseif field_number == 4 + app_metadata = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return FlightEndpoint(ticket[], location[], expiration_time[], app_metadata) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightEndpoint) + initpos = position(e.io) + !isnothing(x.ticket) && PB.encode(e, 1, x.ticket) + !isempty(x.location) && PB.encode(e, 2, x.location) + !isnothing(x.expiration_time) && PB.encode(e, 3, x.expiration_time) + !isempty(x.app_metadata) && PB.encode(e, 4, x.app_metadata) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightEndpoint) + encoded_size = 0 + !isnothing(x.ticket) && (encoded_size += PB._encoded_size(x.ticket, 1)) + !isempty(x.location) && (encoded_size += PB._encoded_size(x.location, 2)) + !isnothing(x.expiration_time) && + (encoded_size += PB._encoded_size(x.expiration_time, 3)) + !isempty(x.app_metadata) && (encoded_size += PB._encoded_size(x.app_metadata, 4)) + return encoded_size +end + +struct CloseSessionResult + status::var"CloseSessionResult.Status".T +end +PB.default_values(::Type{CloseSessionResult}) = + (; status=var"CloseSessionResult.Status".UNSPECIFIED) +PB.field_numbers(::Type{CloseSessionResult}) = (; status=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:CloseSessionResult}) + status = var"CloseSessionResult.Status".UNSPECIFIED + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + status = PB.decode(d, var"CloseSessionResult.Status".T) + else + Base.skip(d, wire_type) + end + end + return CloseSessionResult(status) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CloseSessionResult) + initpos = position(e.io) + x.status != var"CloseSessionResult.Status".UNSPECIFIED && PB.encode(e, 1, x.status) + return position(e.io) - initpos +end +function PB._encoded_size(x::CloseSessionResult) + encoded_size = 0 + x.status != var"CloseSessionResult.Status".UNSPECIFIED && + (encoded_size += PB._encoded_size(x.status, 1)) + return encoded_size +end + +struct FlightData + flight_descriptor::Union{Nothing,FlightDescriptor} + data_header::Vector{UInt8} + app_metadata::Vector{UInt8} + data_body::Vector{UInt8} +end +PB.default_values(::Type{FlightData}) = (; + flight_descriptor=nothing, + data_header=UInt8[], + app_metadata=UInt8[], + data_body=UInt8[], +) +PB.field_numbers(::Type{FlightData}) = + (; flight_descriptor=1, data_header=2, app_metadata=3, data_body=1000) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightData}) + flight_descriptor = Ref{Union{Nothing,FlightDescriptor}}(nothing) + data_header = UInt8[] + app_metadata = UInt8[] + data_body = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, flight_descriptor) + elseif field_number == 2 + data_header = PB.decode(d, Vector{UInt8}) + elseif field_number == 3 + app_metadata = PB.decode(d, Vector{UInt8}) + elseif field_number == 1000 + data_body = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return FlightData(flight_descriptor[], data_header, app_metadata, data_body) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightData) + initpos = position(e.io) + !isnothing(x.flight_descriptor) && PB.encode(e, 1, x.flight_descriptor) + !isempty(x.data_header) && PB.encode(e, 2, x.data_header) + !isempty(x.app_metadata) && PB.encode(e, 3, x.app_metadata) + !isempty(x.data_body) && PB.encode(e, 1000, x.data_body) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightData) + encoded_size = 0 + !isnothing(x.flight_descriptor) && + (encoded_size += PB._encoded_size(x.flight_descriptor, 1)) + !isempty(x.data_header) && (encoded_size += PB._encoded_size(x.data_header, 2)) + !isempty(x.app_metadata) && (encoded_size += PB._encoded_size(x.app_metadata, 3)) + !isempty(x.data_body) && (encoded_size += PB._encoded_size(x.data_body, 1000)) + return encoded_size +end + +struct SetSessionOptionsRequest + session_options::Dict{String,SessionOptionValue} +end +PB.default_values(::Type{SetSessionOptionsRequest}) = + (; session_options=Dict{String,SessionOptionValue}()) +PB.field_numbers(::Type{SetSessionOptionsRequest}) = (; session_options=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SetSessionOptionsRequest}) + session_options = Dict{String,SessionOptionValue}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, session_options) + else + Base.skip(d, wire_type) + end + end + return SetSessionOptionsRequest(session_options) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SetSessionOptionsRequest) + initpos = position(e.io) + !isempty(x.session_options) && PB.encode(e, 1, x.session_options) + return position(e.io) - initpos +end +function PB._encoded_size(x::SetSessionOptionsRequest) + encoded_size = 0 + !isempty(x.session_options) && (encoded_size += PB._encoded_size(x.session_options, 1)) + return encoded_size +end + +struct GetSessionOptionsResult + session_options::Dict{String,SessionOptionValue} +end +PB.default_values(::Type{GetSessionOptionsResult}) = + (; session_options=Dict{String,SessionOptionValue}()) +PB.field_numbers(::Type{GetSessionOptionsResult}) = (; session_options=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:GetSessionOptionsResult}) + session_options = Dict{String,SessionOptionValue}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, session_options) + else + Base.skip(d, wire_type) + end + end + return GetSessionOptionsResult(session_options) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::GetSessionOptionsResult) + initpos = position(e.io) + !isempty(x.session_options) && PB.encode(e, 1, x.session_options) + return position(e.io) - initpos +end +function PB._encoded_size(x::GetSessionOptionsResult) + encoded_size = 0 + !isempty(x.session_options) && (encoded_size += PB._encoded_size(x.session_options, 1)) + return encoded_size +end + +struct SetSessionOptionsResult + errors::Dict{String,var"SetSessionOptionsResult.Error"} +end +PB.default_values(::Type{SetSessionOptionsResult}) = + (; errors=Dict{String,var"SetSessionOptionsResult.Error"}()) +PB.field_numbers(::Type{SetSessionOptionsResult}) = (; errors=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SetSessionOptionsResult}) + errors = Dict{String,var"SetSessionOptionsResult.Error"}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, errors) + else + Base.skip(d, wire_type) + end + end + return SetSessionOptionsResult(errors) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SetSessionOptionsResult) + initpos = position(e.io) + !isempty(x.errors) && PB.encode(e, 1, x.errors) + return position(e.io) - initpos +end +function PB._encoded_size(x::SetSessionOptionsResult) + encoded_size = 0 + !isempty(x.errors) && (encoded_size += PB._encoded_size(x.errors, 1)) + return encoded_size +end + +struct RenewFlightEndpointRequest + endpoint::Union{Nothing,FlightEndpoint} +end +PB.default_values(::Type{RenewFlightEndpointRequest}) = (; endpoint=nothing) +PB.field_numbers(::Type{RenewFlightEndpointRequest}) = (; endpoint=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:RenewFlightEndpointRequest}) + endpoint = Ref{Union{Nothing,FlightEndpoint}}(nothing) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, endpoint) + else + Base.skip(d, wire_type) + end + end + return RenewFlightEndpointRequest(endpoint[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::RenewFlightEndpointRequest) + initpos = position(e.io) + !isnothing(x.endpoint) && PB.encode(e, 1, x.endpoint) + return position(e.io) - initpos +end +function PB._encoded_size(x::RenewFlightEndpointRequest) + encoded_size = 0 + !isnothing(x.endpoint) && (encoded_size += PB._encoded_size(x.endpoint, 1)) + return encoded_size +end + +struct FlightInfo + schema::Vector{UInt8} + flight_descriptor::Union{Nothing,FlightDescriptor} + endpoint::Vector{FlightEndpoint} + total_records::Int64 + total_bytes::Int64 + ordered::Bool + app_metadata::Vector{UInt8} +end +PB.default_values(::Type{FlightInfo}) = (; + schema=UInt8[], + flight_descriptor=nothing, + endpoint=Vector{FlightEndpoint}(), + total_records=zero(Int64), + total_bytes=zero(Int64), + ordered=false, + app_metadata=UInt8[], +) +PB.field_numbers(::Type{FlightInfo}) = (; + schema=1, + flight_descriptor=2, + endpoint=3, + total_records=4, + total_bytes=5, + ordered=6, + app_metadata=7, +) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightInfo}) + schema = UInt8[] + flight_descriptor = Ref{Union{Nothing,FlightDescriptor}}(nothing) + endpoint = PB.BufferedVector{FlightEndpoint}() + total_records = zero(Int64) + total_bytes = zero(Int64) + ordered = false + app_metadata = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + schema = PB.decode(d, Vector{UInt8}) + elseif field_number == 2 + PB.decode!(d, flight_descriptor) + elseif field_number == 3 + PB.decode!(d, endpoint) + elseif field_number == 4 + total_records = PB.decode(d, Int64) + elseif field_number == 5 + total_bytes = PB.decode(d, Int64) + elseif field_number == 6 + ordered = PB.decode(d, Bool) + elseif field_number == 7 + app_metadata = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return FlightInfo( + schema, + flight_descriptor[], + endpoint[], + total_records, + total_bytes, + ordered, + app_metadata, + ) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightInfo) + initpos = position(e.io) + !isempty(x.schema) && PB.encode(e, 1, x.schema) + !isnothing(x.flight_descriptor) && PB.encode(e, 2, x.flight_descriptor) + !isempty(x.endpoint) && PB.encode(e, 3, x.endpoint) + x.total_records != zero(Int64) && PB.encode(e, 4, x.total_records) + x.total_bytes != zero(Int64) && PB.encode(e, 5, x.total_bytes) + x.ordered != false && PB.encode(e, 6, x.ordered) + !isempty(x.app_metadata) && PB.encode(e, 7, x.app_metadata) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightInfo) + encoded_size = 0 + !isempty(x.schema) && (encoded_size += PB._encoded_size(x.schema, 1)) + !isnothing(x.flight_descriptor) && + (encoded_size += PB._encoded_size(x.flight_descriptor, 2)) + !isempty(x.endpoint) && (encoded_size += PB._encoded_size(x.endpoint, 3)) + x.total_records != zero(Int64) && (encoded_size += PB._encoded_size(x.total_records, 4)) + x.total_bytes != zero(Int64) && (encoded_size += PB._encoded_size(x.total_bytes, 5)) + x.ordered != false && (encoded_size += PB._encoded_size(x.ordered, 6)) + !isempty(x.app_metadata) && (encoded_size += PB._encoded_size(x.app_metadata, 7)) + return encoded_size +end + +struct CancelFlightInfoRequest + info::Union{Nothing,FlightInfo} +end +PB.default_values(::Type{CancelFlightInfoRequest}) = (; info=nothing) +PB.field_numbers(::Type{CancelFlightInfoRequest}) = (; info=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:CancelFlightInfoRequest}) + info = Ref{Union{Nothing,FlightInfo}}(nothing) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, info) + else + Base.skip(d, wire_type) + end + end + return CancelFlightInfoRequest(info[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CancelFlightInfoRequest) + initpos = position(e.io) + !isnothing(x.info) && PB.encode(e, 1, x.info) + return position(e.io) - initpos +end +function PB._encoded_size(x::CancelFlightInfoRequest) + encoded_size = 0 + !isnothing(x.info) && (encoded_size += PB._encoded_size(x.info, 1)) + return encoded_size +end + +struct PollInfo + info::Union{Nothing,FlightInfo} + flight_descriptor::Union{Nothing,FlightDescriptor} + progress::Float64 + expiration_time::Union{Nothing,google.protobuf.Timestamp} +end +PB.default_values(::Type{PollInfo}) = (; + info=nothing, + flight_descriptor=nothing, + progress=zero(Float64), + expiration_time=nothing, +) +PB.field_numbers(::Type{PollInfo}) = + (; info=1, flight_descriptor=2, progress=3, expiration_time=4) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:PollInfo}) + info = Ref{Union{Nothing,FlightInfo}}(nothing) + flight_descriptor = Ref{Union{Nothing,FlightDescriptor}}(nothing) + progress = zero(Float64) + expiration_time = Ref{Union{Nothing,google.protobuf.Timestamp}}(nothing) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, info) + elseif field_number == 2 + PB.decode!(d, flight_descriptor) + elseif field_number == 3 + progress = PB.decode(d, Float64) + elseif field_number == 4 + PB.decode!(d, expiration_time) + else + Base.skip(d, wire_type) + end + end + return PollInfo(info[], flight_descriptor[], progress, expiration_time[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::PollInfo) + initpos = position(e.io) + !isnothing(x.info) && PB.encode(e, 1, x.info) + !isnothing(x.flight_descriptor) && PB.encode(e, 2, x.flight_descriptor) + x.progress !== zero(Float64) && PB.encode(e, 3, x.progress) + !isnothing(x.expiration_time) && PB.encode(e, 4, x.expiration_time) + return position(e.io) - initpos +end +function PB._encoded_size(x::PollInfo) + encoded_size = 0 + !isnothing(x.info) && (encoded_size += PB._encoded_size(x.info, 1)) + !isnothing(x.flight_descriptor) && + (encoded_size += PB._encoded_size(x.flight_descriptor, 2)) + x.progress !== zero(Float64) && (encoded_size += PB._encoded_size(x.progress, 3)) + !isnothing(x.expiration_time) && + (encoded_size += PB._encoded_size(x.expiration_time, 4)) + return encoded_size +end + +# gRPCClient.jl BEGIN +FlightService_Handshake_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{HandshakeRequest,true,HandshakeResponse,true}( + host, + port, + "/arrow.flight.protocol.FlightService/Handshake"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_Handshake_Client + +FlightService_ListFlights_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Criteria,false,FlightInfo,true}( + host, + port, + "/arrow.flight.protocol.FlightService/ListFlights"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_ListFlights_Client + +FlightService_GetFlightInfo_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightDescriptor,false,FlightInfo,false}( + host, + port, + "/arrow.flight.protocol.FlightService/GetFlightInfo"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_GetFlightInfo_Client + +FlightService_PollFlightInfo_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightDescriptor,false,PollInfo,false}( + host, + port, + "/arrow.flight.protocol.FlightService/PollFlightInfo"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_PollFlightInfo_Client + +FlightService_GetSchema_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightDescriptor,false,SchemaResult,false}( + host, + port, + "/arrow.flight.protocol.FlightService/GetSchema"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_GetSchema_Client + +FlightService_DoGet_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Ticket,false,FlightData,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoGet"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoGet_Client + +FlightService_DoPut_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightData,true,PutResult,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoPut"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoPut_Client + +FlightService_DoExchange_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightData,true,FlightData,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoExchange"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoExchange_Client + +FlightService_DoAction_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Action,false,Result,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoAction"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoAction_Client + +FlightService_ListActions_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Empty,false,ActionType,true}( + host, + port, + "/arrow.flight.protocol.FlightService/ListActions"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_ListActions_Client +# gRPCClient.jl END diff --git a/src/flight/generated/arrow/flight/protocol/protocol.jl b/src/flight/generated/arrow/flight/protocol/protocol.jl new file mode 100644 index 0000000..0e8132f --- /dev/null +++ b/src/flight/generated/arrow/flight/protocol/protocol.jl @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module protocol + +import ...google + +include("Flight_pb.jl") + +end # module protocol diff --git a/src/flight/generated/google/google.jl b/src/flight/generated/google/google.jl new file mode 100644 index 0000000..eaea425 --- /dev/null +++ b/src/flight/generated/google/google.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module google + +include("protobuf/protobuf.jl") + +end # module google diff --git a/src/flight/generated/google/protobuf/protobuf.jl b/src/flight/generated/google/protobuf/protobuf.jl new file mode 100644 index 0000000..f066b99 --- /dev/null +++ b/src/flight/generated/google/protobuf/protobuf.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module protobuf + +include("timestamp_pb.jl") + +end # module protobuf diff --git a/src/flight/generated/google/protobuf/timestamp_pb.jl b/src/flight/generated/google/protobuf/timestamp_pb.jl new file mode 100644 index 0000000..2831ff7 --- /dev/null +++ b/src/flight/generated/google/protobuf/timestamp_pb.jl @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Autogenerated using ProtoBuf.jl v1.2.3 +# original file: google/protobuf/timestamp.proto (proto3 syntax) + +import ProtoBuf as PB +using ProtoBuf: OneOf +using ProtoBuf.EnumX: @enumx + +export Timestamp + +struct Timestamp + seconds::Int64 + nanos::Int32 +end +PB.default_values(::Type{Timestamp}) = (; seconds=zero(Int64), nanos=zero(Int32)) +PB.field_numbers(::Type{Timestamp}) = (; seconds=1, nanos=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Timestamp}) + seconds = zero(Int64) + nanos = zero(Int32) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + seconds = PB.decode(d, Int64) + elseif field_number == 2 + nanos = PB.decode(d, Int32) + else + Base.skip(d, wire_type) + end + end + return Timestamp(seconds, nanos) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Timestamp) + initpos = position(e.io) + x.seconds != zero(Int64) && PB.encode(e, 1, x.seconds) + x.nanos != zero(Int32) && PB.encode(e, 2, x.nanos) + return position(e.io) - initpos +end +function PB._encoded_size(x::Timestamp) + encoded_size = 0 + x.seconds != zero(Int64) && (encoded_size += PB._encoded_size(x.seconds, 1)) + x.nanos != zero(Int32) && (encoded_size += PB._encoded_size(x.nanos, 2)) + return encoded_size +end diff --git a/src/flight/proto/Flight.proto b/src/flight/proto/Flight.proto new file mode 100644 index 0000000..69e74c5 --- /dev/null +++ b/src/flight/proto/Flight.proto @@ -0,0 +1,678 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +import "google/protobuf/timestamp.proto"; + +option java_package = "org.apache.arrow.flight.impl"; +option go_package = "github.com/apache/arrow-go/arrow/flight/gen/flight"; +option csharp_namespace = "Apache.Arrow.Flight.Protocol"; + +package arrow.flight.protocol; + +/* + * A flight service is an endpoint for retrieving or storing Arrow data. A + * flight service can expose one or more predefined endpoints that can be + * accessed using the Arrow Flight Protocol. Additionally, a flight service + * can expose a set of actions that are available. + */ +service FlightService { + + /* + * Handshake between client and server. Depending on the server, the + * handshake may be required to determine the token that should be used for + * future operations. Both request and response are streams to allow multiple + * round-trips depending on auth mechanism. + */ + rpc Handshake(stream HandshakeRequest) returns (stream HandshakeResponse) {} + + /* + * Get a list of available streams given a particular criteria. Most flight + * services will expose one or more streams that are readily available for + * retrieval. This api allows listing the streams available for + * consumption. A user can also provide a criteria. The criteria can limit + * the subset of streams that can be listed via this interface. Each flight + * service allows its own definition of how to consume criteria. + */ + rpc ListFlights(Criteria) returns (stream FlightInfo) {} + + /* + * For a given FlightDescriptor, get information about how the flight can be + * consumed. This is a useful interface if the consumer of the interface + * already can identify the specific flight to consume. This interface can + * also allow a consumer to generate a flight stream through a specified + * descriptor. For example, a flight descriptor might be something that + * includes a SQL statement or a Pickled Python operation that will be + * executed. In those cases, the descriptor will not be previously available + * within the list of available streams provided by ListFlights but will be + * available for consumption for the duration defined by the specific flight + * service. + */ + rpc GetFlightInfo(FlightDescriptor) returns (FlightInfo) {} + + /* + * For a given FlightDescriptor, start a query and get information + * to poll its execution status. This is a useful interface if the + * query may be a long-running query. The first PollFlightInfo call + * should return as quickly as possible. (GetFlightInfo doesn't + * return until the query is complete.) + * + * A client can consume any available results before + * the query is completed. See PollInfo.info for details. + * + * A client can poll the updated query status by calling + * PollFlightInfo() with PollInfo.flight_descriptor. A server + * should not respond until the result would be different from last + * time. That way, the client can "long poll" for updates + * without constantly making requests. Clients can set a short timeout + * to avoid blocking calls if desired. + * + * A client can't use PollInfo.flight_descriptor after + * PollInfo.expiration_time passes. A server might not accept the + * retry descriptor anymore and the query may be cancelled. + * + * A client may use the CancelFlightInfo action with + * PollInfo.info to cancel the running query. + */ + rpc PollFlightInfo(FlightDescriptor) returns (PollInfo) {} + + /* + * For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema + * This is used when a consumer needs the Schema of flight stream. Similar to + * GetFlightInfo this interface may generate a new flight that was not previously + * available in ListFlights. + */ + rpc GetSchema(FlightDescriptor) returns (SchemaResult) {} + + /* + * Retrieve a single stream associated with a particular descriptor + * associated with the referenced ticket. A Flight can be composed of one or + * more streams where each stream can be retrieved using a separate opaque + * ticket that the flight service uses for managing a collection of streams. + */ + rpc DoGet(Ticket) returns (stream FlightData) {} + + /* + * Push a stream to the flight service associated with a particular + * flight stream. This allows a client of a flight service to upload a stream + * of data. Depending on the particular flight service, a client consumer + * could be allowed to upload a single stream per descriptor or an unlimited + * number. In the latter, the service might implement a 'seal' action that + * can be applied to a descriptor once all streams are uploaded. + */ + rpc DoPut(stream FlightData) returns (stream PutResult) {} + + /* + * Open a bidirectional data channel for a given descriptor. This + * allows clients to send and receive arbitrary Arrow data and + * application-specific metadata in a single logical stream. In + * contrast to DoGet/DoPut, this is more suited for clients + * offloading computation (rather than storage) to a Flight service. + */ + rpc DoExchange(stream FlightData) returns (stream FlightData) {} + + /* + * Flight services can support an arbitrary number of simple actions in + * addition to the possible ListFlights, GetFlightInfo, DoGet, DoPut + * operations that are potentially available. DoAction allows a flight client + * to do a specific action against a flight service. An action includes + * opaque request and response objects that are specific to the type action + * being undertaken. + */ + rpc DoAction(Action) returns (stream Result) {} + + /* + * A flight service exposes all of the available action types that it has + * along with descriptions. This allows different flight consumers to + * understand the capabilities of the flight service. + */ + rpc ListActions(Empty) returns (stream ActionType) {} +} + +/* + * The request that a client provides to a server on handshake. + */ +message HandshakeRequest { + + /* + * A defined protocol version + */ + uint64 protocol_version = 1; + + /* + * Arbitrary auth/handshake info. + */ + bytes payload = 2; +} + +message HandshakeResponse { + + /* + * A defined protocol version + */ + uint64 protocol_version = 1; + + /* + * Arbitrary auth/handshake info. + */ + bytes payload = 2; +} + +/* + * A message for doing simple auth. + */ +message BasicAuth { + string username = 2; + string password = 3; +} + +message Empty {} + +/* + * Describes an available action, including both the name used for execution + * along with a short description of the purpose of the action. + */ +message ActionType { + string type = 1; + string description = 2; +} + +/* + * A service specific expression that can be used to return a limited set + * of available Arrow Flight streams. + */ +message Criteria { + bytes expression = 1; +} + +/* + * An opaque action specific for the service. + */ +message Action { + string type = 1; + bytes body = 2; +} + +/* + * An opaque result returned after executing an action. + */ +message Result { + bytes body = 1; +} + +/* + * Wrap the result of a getSchema call + */ +message SchemaResult { + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + bytes schema = 1; +} + +/* + * The name or tag for a Flight. May be used as a way to retrieve or generate + * a flight or be used to expose a set of previously defined flights. + */ +message FlightDescriptor { + + /* + * Describes what type of descriptor is defined. + */ + enum DescriptorType { + + // Protobuf pattern, not used. + UNKNOWN = 0; + + /* + * A named path that identifies a dataset. A path is composed of a string + * or list of strings describing a particular dataset. This is conceptually + * similar to a path inside a filesystem. + */ + PATH = 1; + + /* + * An opaque command to generate a dataset. + */ + CMD = 2; + } + + DescriptorType type = 1; + + /* + * Opaque value used to express a command. Should only be defined when + * type = CMD. + */ + bytes cmd = 2; + + /* + * List of strings identifying a particular dataset. Should only be defined + * when type = PATH. + */ + repeated string path = 3; +} + +/* + * The access coordinates for retrieval of a dataset. With a FlightInfo, a + * consumer is able to determine how to retrieve a dataset. + */ +message FlightInfo { + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + bytes schema = 1; + + /* + * The descriptor associated with this info. + */ + FlightDescriptor flight_descriptor = 2; + + /* + * A list of endpoints associated with the flight. To consume the + * whole flight, all endpoints (and hence all Tickets) must be + * consumed. Endpoints can be consumed in any order. + * + * In other words, an application can use multiple endpoints to + * represent partitioned data. + * + * If the returned data has an ordering, an application can use + * "FlightInfo.ordered = true" or should return all data in a + * single endpoint. Otherwise, there is no ordering defined on + * endpoints or the data within. + * + * A client can read ordered data by reading data from returned + * endpoints, in order, from front to back. + * + * Note that a client may ignore "FlightInfo.ordered = true". If an + * ordering is important for an application, an application must + * choose one of them: + * + * * An application requires that all clients must read data in + * returned endpoints order. + * * An application must return all data in a single endpoint. + */ + repeated FlightEndpoint endpoint = 3; + + // Set these to -1 if unknown. + int64 total_records = 4; + int64 total_bytes = 5; + + /* + * FlightEndpoints are in the same order as the data. + */ + bool ordered = 6; + + /* + * Application-defined metadata. + * + * There is no inherent or required relationship between this + * and the app_metadata fields in the FlightEndpoints or resulting + * FlightData messages. Since this metadata is application-defined, + * a given application could define there to be a relationship, + * but there is none required by the spec. + */ + bytes app_metadata = 7; +} + +/* + * The information to process a long-running query. + */ +message PollInfo { + /* + * The currently available results. + * + * If "flight_descriptor" is not specified, the query is complete + * and "info" specifies all results. Otherwise, "info" contains + * partial query results. + * + * Note that each PollInfo response contains a complete + * FlightInfo (not just the delta between the previous and current + * FlightInfo). + * + * Subsequent PollInfo responses may only append new endpoints to + * info. + * + * Clients can begin fetching results via DoGet(Ticket) with the + * ticket in the info before the query is + * completed. FlightInfo.ordered is also valid. + */ + FlightInfo info = 1; + + /* + * The descriptor the client should use on the next try. + * If unset, the query is complete. + */ + FlightDescriptor flight_descriptor = 2; + + /* + * Query progress. If known, must be in [0.0, 1.0] but need not be + * monotonic or nondecreasing. If unknown, do not set. + */ + optional double progress = 3; + + /* + * Expiration time for this request. After this passes, the server + * might not accept the retry descriptor anymore (and the query may + * be cancelled). This may be updated on a call to PollFlightInfo. + */ + google.protobuf.Timestamp expiration_time = 4; +} + +/* + * The request of the CancelFlightInfo action. + * + * The request should be stored in Action.body. + */ +message CancelFlightInfoRequest { + FlightInfo info = 1; +} + +/* + * The result of a cancel operation. + * + * This is used by CancelFlightInfoResult.status. + */ +enum CancelStatus { + // The cancellation status is unknown. Servers should avoid using + // this value (send a NOT_FOUND error if the requested query is + // not known). Clients can retry the request. + CANCEL_STATUS_UNSPECIFIED = 0; + // The cancellation request is complete. Subsequent requests with + // the same payload may return CANCELLED or a NOT_FOUND error. + CANCEL_STATUS_CANCELLED = 1; + // The cancellation request is in progress. The client may retry + // the cancellation request. + CANCEL_STATUS_CANCELLING = 2; + // The query is not cancellable. The client should not retry the + // cancellation request. + CANCEL_STATUS_NOT_CANCELLABLE = 3; +} + +/* + * The result of the CancelFlightInfo action. + * + * The result should be stored in Result.body. + */ +message CancelFlightInfoResult { + CancelStatus status = 1; +} + +/* + * An opaque identifier that the service can use to retrieve a particular + * portion of a stream. + * + * Tickets are meant to be single use. It is an error/application-defined + * behavior to reuse a ticket. + */ +message Ticket { + bytes ticket = 1; +} + +/* + * A location to retrieve a particular stream from. This URI should be one of + * the following: + * - An empty string or the string 'arrow-flight-reuse-connection://?': + * indicating that the ticket can be redeemed on the service where the + * ticket was generated via a DoGet request. + * - A valid grpc URI (grpc://, grpc+tls://, grpc+unix://, etc.): + * indicating that the ticket can be redeemed on the service at the given + * URI via a DoGet request. + * - A valid HTTP URI (http://, https://, etc.): + * indicating that the client should perform a GET request against the + * given URI to retrieve the stream. The ticket should be empty + * in this case and should be ignored by the client. Cloud object storage + * can be utilized by presigned URLs or mediating the auth separately and + * returning the full URL (e.g. https://amzn-s3-demo-bucket.s3.us-west-2.amazonaws.com/...). + * + * We allow non-Flight URIs for the purpose of allowing Flight services to indicate that + * results can be downloaded in formats other than Arrow (such as Parquet) or to allow + * direct fetching of results from a URI to reduce excess copying and data movement. + * In these cases, the following conventions should be followed by servers and clients: + * + * - Unless otherwise specified by the 'Content-Type' header of the response, + * a client should assume the response is using the Arrow IPC Streaming format. + * Usage of an IANA media type like 'application/octet-stream' should be assumed to + * be using the Arrow IPC Streaming format. + * - The server may allow the client to choose a specific response format by + * specifying an 'Accept' header in the request, such as 'application/vnd.apache.parquet' + * or 'application/vnd.apache.arrow.stream'. If multiple types are requested and + * supported by the server, the choice of which to use is server-specific. If + * none of the requested content-types are supported, the server may respond with + * either 406 (Not Acceptable) or 415 (Unsupported Media Type), or successfully + * respond with a different format that it does support along with the correct + * 'Content-Type' header. + * + * Note: new schemes may be proposed in the future to allow for more flexibility based + * on community requests. + */ +message Location { + string uri = 1; +} + +/* + * A particular stream or split associated with a flight. + */ +message FlightEndpoint { + + /* + * Token used to retrieve this stream. + */ + Ticket ticket = 1; + + /* + * A list of URIs where this ticket can be redeemed via DoGet(). + * + * If the list is empty, the expectation is that the ticket can only + * be redeemed on the current service where the ticket was + * generated. + * + * If the list is not empty, the expectation is that the ticket can be + * redeemed at any of the locations, and that the data returned will be + * equivalent. In this case, the ticket may only be redeemed at one of the + * given locations, and not (necessarily) on the current service. If one + * of the given locations is "arrow-flight-reuse-connection://?", the + * client may redeem the ticket on the service where the ticket was + * generated (i.e., the same as above), in addition to the other + * locations. (This URI was chosen to maximize compatibility, as 'scheme:' + * or 'scheme://' are not accepted by Java's java.net.URI.) + * + * In other words, an application can use multiple locations to + * represent redundant and/or load balanced services. + */ + repeated Location location = 2; + + /* + * Expiration time of this stream. If present, clients may assume + * they can retry DoGet requests. Otherwise, it is + * application-defined whether DoGet requests may be retried. + */ + google.protobuf.Timestamp expiration_time = 3; + + /* + * Application-defined metadata. + * + * There is no inherent or required relationship between this + * and the app_metadata fields in the FlightInfo or resulting + * FlightData messages. Since this metadata is application-defined, + * a given application could define there to be a relationship, + * but there is none required by the spec. + */ + bytes app_metadata = 4; +} + +/* + * The request of the RenewFlightEndpoint action. + * + * The request should be stored in Action.body. + */ +message RenewFlightEndpointRequest { + FlightEndpoint endpoint = 1; +} + +/* + * A batch of Arrow data as part of a stream of batches. + */ +message FlightData { + + /* + * The descriptor of the data. This is only relevant when a client is + * starting a new DoPut stream. + */ + FlightDescriptor flight_descriptor = 1; + + /* + * Header for message data as described in Message.fbs::Message. + */ + bytes data_header = 2; + + /* + * Application-defined metadata. + */ + bytes app_metadata = 3; + + /* + * The actual batch of Arrow data. Preferably handled with minimal-copies + * coming last in the definition to help with sidecar patterns (it is + * expected that some implementations will fetch this field off the wire + * with specialized code to avoid extra memory copies). + */ + bytes data_body = 1000; +} + +/** + * The response message associated with the submission of a DoPut. + */ +message PutResult { + bytes app_metadata = 1; +} + +/* + * EXPERIMENTAL: Union of possible value types for a Session Option to be set to. + * + * By convention, an attempt to set a valueless SessionOptionValue should + * attempt to unset or clear the named option value on the server. + */ +message SessionOptionValue { + message StringListValue { + repeated string values = 1; + } + + oneof option_value { + string string_value = 1; + bool bool_value = 2; + sfixed64 int64_value = 3; + double double_value = 4; + StringListValue string_list_value = 5; + } +} + +/* + * EXPERIMENTAL: A request to set session options for an existing or new (implicit) + * server session. + * + * Sessions are persisted and referenced via a transport-level state management, typically + * RFC 6265 HTTP cookies when using an HTTP transport. The suggested cookie name or state + * context key is 'arrow_flight_session_id', although implementations may freely choose their + * own name. + * + * Session creation (if one does not already exist) is implied by this RPC request, however + * server implementations may choose to initiate a session that also contains client-provided + * session options at any other time, e.g. on authentication, or when any other call is made + * and the server wishes to use a session to persist any state (or lack thereof). + */ +message SetSessionOptionsRequest { + map session_options = 1; +} + +/* + * EXPERIMENTAL: The results (individually) of setting a set of session options. + * + * Option names should only be present in the response if they were not successfully + * set on the server; that is, a response without an Error for a name provided in the + * SetSessionOptionsRequest implies that the named option value was set successfully. + */ +message SetSessionOptionsResult { + enum ErrorValue { + // Protobuf deserialization fallback value: The status is unknown or unrecognized. + // Servers should avoid using this value. The request may be retried by the client. + UNSPECIFIED = 0; + // The given session option name is invalid. + INVALID_NAME = 1; + // The session option value or type is invalid. + INVALID_VALUE = 2; + // The session option cannot be set. + ERROR = 3; + } + + message Error { + ErrorValue value = 1; + } + + map errors = 1; +} + +/* + * EXPERIMENTAL: A request to access the session options for the current server session. + * + * The existing session is referenced via a cookie header or similar (see + * SetSessionOptionsRequest above); it is an error to make this request with a missing, + * invalid, or expired session cookie header or other implementation-defined session + * reference token. + */ +message GetSessionOptionsRequest { +} + +/* + * EXPERIMENTAL: The result containing the current server session options. + */ +message GetSessionOptionsResult { + map session_options = 1; +} + +/* + * Request message for the "Close Session" action. + * + * The existing session is referenced via a cookie header. + */ +message CloseSessionRequest { +} + +/* + * The result of closing a session. + */ +message CloseSessionResult { + enum Status { + // Protobuf deserialization fallback value: The session close status is unknown or + // not recognized. Servers should avoid using this value (send a NOT_FOUND error if + // the requested session is not known or expired). Clients can retry the request. + UNSPECIFIED = 0; + // The session close request is complete. Subsequent requests with + // the same session produce a NOT_FOUND error. + CLOSED = 1; + // The session close request is in progress. The client may retry + // the close request. + CLOSING = 2; + // The session is not closeable. The client should not retry the + // close request. + NOT_CLOSEABLE = 3; + } + + Status status = 1; +} diff --git a/src/flight/protocol.jl b/src/flight/protocol.jl new file mode 100644 index 0000000..a1f74c0 --- /dev/null +++ b/src/flight/protocol.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module Generated +include("generated/arrow/arrow.jl") +end + +const Protocol = Generated.arrow.flight.protocol diff --git a/src/flight/server.jl b/src/flight/server.jl new file mode 100644 index 0000000..e7f1212 --- /dev/null +++ b/src/flight/server.jl @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("server/types.jl") +include("server/descriptors.jl") +include("server/handlers.jl") +include("server/dispatch.jl") diff --git a/src/flight/server/descriptors.jl b/src/flight/server/descriptors.jl new file mode 100644 index 0000000..913b609 --- /dev/null +++ b/src/flight/server/descriptors.jl @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const FLIGHT_SERVICE_NAME = "arrow.flight.protocol.FlightService" + +struct MethodDescriptor + name::String + path::String + handler_field::Symbol + request_streaming::Bool + response_streaming::Bool + request_type::Type + response_type::Type +end + +struct ServiceDescriptor + name::String + methods::Vector{MethodDescriptor} + method_lookup::Dict{String,MethodDescriptor} +end + +function MethodDescriptor( + name::AbstractString, + handler_field::Symbol, + request_streaming::Bool, + response_streaming::Bool, + request_type::Type, + response_type::Type, +) + normalized_name = String(name) + MethodDescriptor( + normalized_name, + "/$(FLIGHT_SERVICE_NAME)/$(normalized_name)", + handler_field, + request_streaming, + response_streaming, + request_type, + response_type, + ) +end + +function ServiceDescriptor(name::AbstractString, methods::Vector{MethodDescriptor}) + lookup = Dict{String,MethodDescriptor}() + for method in methods + lookup[method.name] = method + lookup[method.path] = method + end + return ServiceDescriptor(String(name), methods, lookup) +end + +const FLIGHT_METHODS = [ + MethodDescriptor( + "Handshake", + :handshake, + true, + true, + Protocol.HandshakeRequest, + Protocol.HandshakeResponse, + ), + MethodDescriptor( + "ListFlights", + :listflights, + false, + true, + Protocol.Criteria, + Protocol.FlightInfo, + ), + MethodDescriptor( + "GetFlightInfo", + :getflightinfo, + false, + false, + Protocol.FlightDescriptor, + Protocol.FlightInfo, + ), + MethodDescriptor( + "PollFlightInfo", + :pollflightinfo, + false, + false, + Protocol.FlightDescriptor, + Protocol.PollInfo, + ), + MethodDescriptor( + "GetSchema", + :getschema, + false, + false, + Protocol.FlightDescriptor, + Protocol.SchemaResult, + ), + MethodDescriptor("DoGet", :doget, false, true, Protocol.Ticket, Protocol.FlightData), + MethodDescriptor("DoPut", :doput, true, true, Protocol.FlightData, Protocol.PutResult), + MethodDescriptor( + "DoExchange", + :doexchange, + true, + true, + Protocol.FlightData, + Protocol.FlightData, + ), + MethodDescriptor("DoAction", :doaction, false, true, Protocol.Action, Protocol.Result), + MethodDescriptor( + "ListActions", + :listactions, + false, + true, + Protocol.Empty, + Protocol.ActionType, + ), +] + +const FLIGHT_SERVICE_DESCRIPTOR = ServiceDescriptor(FLIGHT_SERVICE_NAME, FLIGHT_METHODS) + +servicedescriptor(::Service) = FLIGHT_SERVICE_DESCRIPTOR + +function lookupmethod(descriptor::ServiceDescriptor, key::AbstractString) + return get(descriptor.method_lookup, String(key), nothing) +end + +lookupmethod(service::Service, key::AbstractString) = + lookupmethod(servicedescriptor(service), key) diff --git a/src/flight/server/dispatch.jl b/src/flight/server/dispatch.jl new file mode 100644 index 0000000..2f3241b --- /dev/null +++ b/src/flight/server/dispatch.jl @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function dispatch( + service::Service, + context::ServerCallContext, + method::MethodDescriptor, + args..., +) + if method.handler_field === :handshake + return handshake(service, context, args...) + elseif method.handler_field === :listflights + return listflights(service, context, args...) + elseif method.handler_field === :getflightinfo + return getflightinfo(service, context, args...) + elseif method.handler_field === :pollflightinfo + return pollflightinfo(service, context, args...) + elseif method.handler_field === :getschema + return getschema(service, context, args...) + elseif method.handler_field === :doget + return doget(service, context, args...) + elseif method.handler_field === :doput + return doput(service, context, args...) + elseif method.handler_field === :doexchange + return doexchange(service, context, args...) + elseif method.handler_field === :doaction + return doaction(service, context, args...) + elseif method.handler_field === :listactions + return listactions(service, context, args...) + end + + throw(ArgumentError("unsupported Arrow Flight handler field $(method.handler_field)")) +end + +function dispatch( + service::Service, + context::ServerCallContext, + key::AbstractString, + args..., +) + method = lookupmethod(service, key) + isnothing(method) && + throw(ArgumentError("unknown Arrow Flight method path or name: $(String(key))")) + return dispatch(service, context, method, args...) +end diff --git a/src/flight/server/handlers.jl b/src/flight/server/handlers.jl new file mode 100644 index 0000000..c0be8a7 --- /dev/null +++ b/src/flight/server/handlers.jl @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _unimplemented(service_method::String) + throw( + gRPCClient.gRPCServiceCallException( + gRPCClient.GRPC_UNIMPLEMENTED, + "Arrow Flight server method $(service_method) is not implemented", + ), + ) +end + +function _invoke_handler(handler::Union{Nothing,Function}, service_method::String, args...) + isnothing(handler) && _unimplemented(service_method) + return handler(args...) +end + +handshake( + service::Service, + context::ServerCallContext, + request::Channel{Protocol.HandshakeRequest}, + response::Channel{Protocol.HandshakeResponse}, +) = _invoke_handler(service.handshake, "Handshake", context, request, response) + +listflights( + service::Service, + context::ServerCallContext, + criteria::Protocol.Criteria, + response::Channel{Protocol.FlightInfo}, +) = _invoke_handler(service.listflights, "ListFlights", context, criteria, response) + +getflightinfo( + service::Service, + context::ServerCallContext, + descriptor::Protocol.FlightDescriptor, +) = _invoke_handler(service.getflightinfo, "GetFlightInfo", context, descriptor) + +pollflightinfo( + service::Service, + context::ServerCallContext, + descriptor::Protocol.FlightDescriptor, +) = _invoke_handler(service.pollflightinfo, "PollFlightInfo", context, descriptor) + +getschema( + service::Service, + context::ServerCallContext, + descriptor::Protocol.FlightDescriptor, +) = _invoke_handler(service.getschema, "GetSchema", context, descriptor) + +doget( + service::Service, + context::ServerCallContext, + ticket::Protocol.Ticket, + response::Channel{Protocol.FlightData}, +) = _invoke_handler(service.doget, "DoGet", context, ticket, response) + +doput( + service::Service, + context::ServerCallContext, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.PutResult}, +) = _invoke_handler(service.doput, "DoPut", context, request, response) + +doexchange( + service::Service, + context::ServerCallContext, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.FlightData}, +) = _invoke_handler(service.doexchange, "DoExchange", context, request, response) + +doaction( + service::Service, + context::ServerCallContext, + action::Protocol.Action, + response::Channel{Protocol.Result}, +) = _invoke_handler(service.doaction, "DoAction", context, action, response) + +listactions( + service::Service, + context::ServerCallContext, + response::Channel{Protocol.ActionType}, +) = _invoke_handler(service.listactions, "ListActions", context, response) diff --git a/src/flight/server/types.jl b/src/flight/server/types.jl new file mode 100644 index 0000000..71b9872 --- /dev/null +++ b/src/flight/server/types.jl @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const ServerHeaderPair = HeaderPair + +Base.@kwdef struct ServerCallContext + headers::Vector{ServerHeaderPair} = ServerHeaderPair[] + peer::Union{Nothing,String} = nothing + secure::Bool = false +end + +Base.@kwdef struct Service + handshake::Union{Nothing,Function} = nothing + listflights::Union{Nothing,Function} = nothing + getflightinfo::Union{Nothing,Function} = nothing + pollflightinfo::Union{Nothing,Function} = nothing + getschema::Union{Nothing,Function} = nothing + doget::Union{Nothing,Function} = nothing + doput::Union{Nothing,Function} = nothing + doexchange::Union{Nothing,Function} = nothing + doaction::Union{Nothing,Function} = nothing + listactions::Union{Nothing,Function} = nothing +end + +function callheader(context::ServerCallContext, name::AbstractString) + needle = lowercase(String(name)) + for (header_name, header_value) in context.headers + lowercase(header_name) == needle && return header_value + end + return nothing +end diff --git a/src/logicaltypes.jl b/src/logicaltypes.jl new file mode 100644 index 0000000..7695692 --- /dev/null +++ b/src/logicaltypes.jl @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const EXTENSION_NAME_KEY = "ARROW:extension:name" +const EXTENSION_METADATA_KEY = "ARROW:extension:metadata" + +struct ExtensionTypeSpec + name::Symbol + metadata::String +end + +@inline _extensiontypename(spec::ExtensionTypeSpec) = String(spec.name) +@inline _builtinextensionspec(::Type{T}) where {T} = nothing +@inline _builtinextensionjuliatype(::Val{name}, storageT) where {name} = + _builtinextensionjuliatype(Val(name), storageT, "") +@inline _builtinextensionjuliatype(::Val{name}, storageT, metadata) where {name} = nothing +@inline _builtinarrowtype(::Type{T}) where {T} = nothing +@inline _builtintoarrow(x) = nothing +@inline _builtinarrowname(::Type{T}) where {T} = nothing +function _builtinfromarrow end +function _builtinfromarrowstruct end +function _builtindefault end +function _builtinopaquemetadata end +function _builtinvariantmetadata end +function _builtinfixedshapetensormetadata end +function _builtinvariableshapetensormetadata end +@inline _validatebuiltinextension(::Val{name}, field, metadata) where {name} = nothing + +@inline function _extensionmetadatafor(::Type{T}, meta) where {T} + spec = _extensionspec(T) + spec === nothing && return meta + return _mergeextensionmeta(meta, spec) +end + +@inline function _extensionspec(::Type{T}) where {T} + spec = _builtinextensionspec(T) + spec !== nothing && return spec + ArrowTypes.hasarrowname(T) || return nothing + return ExtensionTypeSpec(ArrowTypes.arrowname(T), String(ArrowTypes.arrowmetadata(T))) +end + +@inline function _extensionspec(meta::AbstractDict{String,String}) + haskey(meta, EXTENSION_NAME_KEY) || return nothing + return ExtensionTypeSpec( + Symbol(meta[EXTENSION_NAME_KEY]), + get(meta, EXTENSION_METADATA_KEY, ""), + ) +end + +function _mergeextensionmeta(::Nothing, spec::ExtensionTypeSpec) + return toidict(( + EXTENSION_NAME_KEY => _extensiontypename(spec), + EXTENSION_METADATA_KEY => spec.metadata, + ),) +end + +function _mergeextensionmeta(::Nothing, name::Symbol, metadata::String) + return toidict((EXTENSION_NAME_KEY => String(name), EXTENSION_METADATA_KEY => metadata)) +end + +function _mergeextensionmeta(meta, spec::ExtensionTypeSpec) + dict = Dict(meta) + dict[EXTENSION_NAME_KEY] = _extensiontypename(spec) + dict[EXTENSION_METADATA_KEY] = spec.metadata + return toidict(dict) +end + +function _mergeextensionmeta(meta, name::Symbol, metadata::String) + dict = Dict(meta) + dict[EXTENSION_NAME_KEY] = String(name) + dict[EXTENSION_METADATA_KEY] = metadata + return toidict(dict) +end + +@inline function _builtinextensionjuliatype(spec::ExtensionTypeSpec, storageT) + return _builtinextensionjuliatype(Val(spec.name), storageT, spec.metadata) +end + +@inline function _resolveextensionjuliatype(spec::ExtensionTypeSpec, storageT) + builtin = _builtinextensionjuliatype(spec, storageT) + builtin !== nothing && return builtin + return ArrowTypes.JuliaType(Val(spec.name), storageT, spec.metadata) +end + +@inline function _validatebuiltinextension(spec::ExtensionTypeSpec, field::Meta.Field) + return _validatebuiltinextension(Val(spec.name), field, spec.metadata) +end diff --git a/src/logicaltypes_builtin.jl b/src/logicaltypes_builtin.jl new file mode 100644 index 0000000..428811e --- /dev/null +++ b/src/logicaltypes_builtin.jl @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_builtinarrowtype(::Type{ArrowTypes.UUID}) = NTuple{16,UInt8} +_builtintoarrow(x::ArrowTypes.UUID) = ArrowTypes._cast(NTuple{16,UInt8}, x.value) +_builtinarrowname(::Type{ArrowTypes.UUID}) = ArrowTypes.UUIDSYMBOL +_builtinextensionspec(::Type{ArrowTypes.UUID}) = + ExtensionTypeSpec(_builtinarrowname(ArrowTypes.UUID), "") +_builtinextensionjuliatype(::Val{ArrowTypes.UUIDSYMBOL}, S, metadata::String) = + ArrowTypes.UUID +_builtinextensionjuliatype(::Val{ArrowTypes.LEGACY_UUIDSYMBOL}, S, metadata::String) = + ArrowTypes.UUID + +_builtinextensionspec(::Type{Bool8}) = ExtensionTypeSpec(BOOL8_SYMBOL, "") +_builtinarrowtype(::Type{Bool8}) = Int8 +_builtintoarrow(x::Bool8) = Int8(Bool(x)) +_builtinarrowname(::Type{Bool8}) = BOOL8_SYMBOL +_builtinextensionjuliatype(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 +_builtinfromarrow(::Type{Bool8}, x::Int8) = Bool8(x) +_builtindefault(::Type{Bool8}) = zero(Bool8) + +_builtinextensionspec(::Type{JSONText{S}}) where {S<:AbstractString} = + ExtensionTypeSpec(JSON_SYMBOL, "") +_builtinarrowtype(::Type{JSONText{S}}) where {S<:AbstractString} = S +_builtintoarrow(x::JSONText) = getfield(x, :value) +_builtinarrowname(::Type{JSONText{S}}) where {S<:AbstractString} = JSON_SYMBOL +_builtinextensionjuliatype( + ::Val{JSON_SYMBOL}, + ::Type{S}, + metadata::String, +) where {S<:AbstractString} = JSONText{S} +_builtinfromarrow(::Type{JSONText{String}}, ptr::Ptr{UInt8}, len::Int) = + JSONText(unsafe_string(ptr, len)) +_builtinfromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = JSONText{S}(x) +_builtindefault(::Type{JSONText{S}}) where {S<:AbstractString} = + JSONText{S}(ArrowTypes.default(S)) + +_builtinextensionjuliatype(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +_builtinopaquemetadata(type_name::AbstractString, vendor_name::AbstractString) = + "{\"type_name\":" * + _jsonstringliteral(type_name) * + ",\"vendor_name\":" * + _jsonstringliteral(vendor_name) * + "}" +_builtinvariantmetadata() = "" + +function _builtinfixedshapetensormetadata( + shape::AbstractVector{<:Integer}; + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + parsed_shape = _parseintvector(FIXED_SHAPE_TENSOR_SYMBOL, collect(shape), "shape") + parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) + parsed_permutation = + permutation === nothing ? nothing : + _validatepermutation( + FIXED_SHAPE_TENSOR_SYMBOL, + Int.(permutation), + length(parsed_shape), + ) + parsed_dim_names !== nothing && length(parsed_dim_names) == length(parsed_shape) || + isnothing(parsed_dim_names) || + _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "\"dim_names\" must have length $(length(parsed_shape))", + ) + body = Dict{String,Any}("shape" => parsed_shape) + parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) + parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) + return JSON3.write(body) +end + +function _builtinvariableshapetensormetadata(; + uniform_shape::Union{Nothing,AbstractVector}=nothing, + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + uniform = + uniform_shape === nothing ? nothing : + _parseintvector( + VARIABLE_SHAPE_TENSOR_SYMBOL, + collect(uniform_shape), + "uniform_shape"; + allow_null=true, + ) + ndim = uniform === nothing ? nothing : length(uniform) + parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) + parsed_permutation = permutation === nothing ? nothing : Int.(permutation) + ndim !== nothing && parsed_dim_names !== nothing && length(parsed_dim_names) == ndim || + ndim === nothing || + isnothing(parsed_dim_names) || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"dim_names\" must have length $ndim", + ) + ndim !== nothing && + parsed_permutation !== nothing && + _validatepermutation(VARIABLE_SHAPE_TENSOR_SYMBOL, parsed_permutation, ndim) + body = Dict{String,Any}() + uniform !== nothing && (body["uniform_shape"] = uniform) + parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) + parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) + return isempty(body) ? "" : JSON3.write(body) +end +_validatebuiltinextension( + ::Val{PARQUET_VARIANT_SYMBOL}, + field::Meta.Field, + metadata::String, +) = _validateparquetvariant(field, metadata) +_validatebuiltinextension( + ::Val{FIXED_SHAPE_TENSOR_SYMBOL}, + field::Meta.Field, + metadata::String, +) = _validatefixedshapetensor(field, metadata) +_validatebuiltinextension( + ::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, + field::Meta.Field, + metadata::String, +) = _validatevariableshapetensor(field, metadata) + +_builtinextensionspec(::Type{ZonedDateTime}) = ExtensionTypeSpec(ZONEDDATETIME_SYMBOL, "") +_builtinarrowtype(::Type{ZonedDateTime}) = Timestamp +_builtintoarrow(x::ZonedDateTime) = + convert(Timestamp{Meta.TimeUnit.MILLISECOND,Symbol(x.timezone)}, x) +_builtinarrowname(::Type{ZonedDateTime}) = ZONEDDATETIME_SYMBOL +_builtinextensionjuliatype(::Val{ZONEDDATETIME_SYMBOL}, S, metadata::String) = ZonedDateTime +_builtinfromarrow(::Type{ZonedDateTime}, x::Timestamp) = convert(ZonedDateTime, x) +_builtindefault(::Type{TimeZones.ZonedDateTime}) = + TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC") + +_builtinextensionspec(::Type{TimestampWithOffset{U}}) where {U} = + ExtensionTypeSpec(TIMESTAMP_WITH_OFFSET_SYMBOL, "") +_builtinarrowtype(::Type{TimestampWithOffset{U}}) where {U} = + NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}} +_builtintoarrow(x::TimestampWithOffset{U}) where {U} = + (timestamp=getfield(x, :timestamp), offset_minutes=getfield(x, :offset_minutes)) +_builtinarrowname(::Type{TimestampWithOffset{U}}) where {U} = TIMESTAMP_WITH_OFFSET_SYMBOL +_builtinextensionjuliatype( + ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, + ::Type{NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, + metadata::String, +) where {U} = TimestampWithOffset{U} +_builtindefault(::Type{TimestampWithOffset{U}}) where {U} = zero(TimestampWithOffset{U}) +_builtinfromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:timestamp, :offset_minutes)}, + timestamp::Timestamp{U,:UTC}, + offset_minutes::Int16, +) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) +_builtinfromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:offset_minutes, :timestamp)}, + offset_minutes::Int16, + timestamp::Timestamp{U,:UTC}, +) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) + +_builtinextensionjuliatype(::Val{OLD_ZONEDDATETIME_SYMBOL}, S, metadata::String) = + LocalZonedDateTime +function _builtinfromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} + (U === Meta.TimeUnit.MICROSECOND || U == Meta.TimeUnit.NANOSECOND) && + warntimestamp(U, ZonedDateTime) + return ZonedDateTime( + Dates.DateTime( + Dates.UTM(Int64(Dates.toms(periodtype(U)(x.x)) + UNIX_EPOCH_DATETIME)), + ), + TimeZone(String(TZ)), + ) +end diff --git a/src/metadata/Message.jl b/src/metadata/Message.jl index 0e49439..5649191 100644 --- a/src/metadata/Message.jl +++ b/src/metadata/Message.jl @@ -157,12 +157,28 @@ dictionaryBatchAddIsDelta(b::FlatBuffers.Builder, isdelta::Base.Bool) = FlatBuffers.prependslot!(b, 2, isdelta, false) dictionaryBatchEnd(b::FlatBuffers.Builder) = FlatBuffers.endobject!(b) +struct Tensor <: FlatBuffers.Table + bytes::Vector{UInt8} + pos::Base.Int +end + +Base.propertynames(x::Tensor) = () +Base.getproperty(x::Tensor, field::Symbol) = nothing + +struct SparseTensor <: FlatBuffers.Table + bytes::Vector{UInt8} + pos::Base.Int +end + +Base.propertynames(x::SparseTensor) = () +Base.getproperty(x::SparseTensor, field::Symbol) = nothing + function MessageHeader(b::UInt8) b == 1 && return Schema b == 2 && return DictionaryBatch b == 3 && return RecordBatch - # b == 4 && return Tensor - # b == 5 && return SparseTensor + b == 4 && return Tensor + b == 5 && return SparseTensor return nothing end @@ -170,8 +186,8 @@ function MessageHeader(::Base.Type{T})::Int16 where {T} T == Schema && return 1 T == DictionaryBatch && return 2 T == RecordBatch && return 3 - # T == Tensor && return 4 - # T == SparseTensor && return 5 + T == Tensor && return 4 + T == SparseTensor && return 5 return 0 end diff --git a/src/metadata/overlay.jl b/src/metadata/overlay.jl new file mode 100644 index 0000000..6c00ea0 --- /dev/null +++ b/src/metadata/overlay.jl @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +_metadata_entries(metadata) = metadata isa AbstractVector ? metadata : pairs(metadata) + +function _normalize_metadata_overlay(metadata) + metadata === nothing && return nothing + return toidict( + String(first(entry)) => String(last(entry)) for entry in _metadata_entries(metadata) + ) +end + +function _merge_metadata_overlays(metadata_sources...) + merged = Dict{String,String}() + for metadata in metadata_sources + metadata === nothing && continue + for entry in _metadata_entries(metadata) + merged[String(first(entry))] = String(last(entry)) + end + end + return isempty(merged) ? nothing : toidict(pairs(merged)) +end + +struct MetadataOverlayVector{T,V<:AbstractVector{T},M} <: AbstractVector{T} + data::V + metadata::M +end + +Base.IndexStyle(::Type{<:MetadataOverlayVector{T,V}}) where {T,V} = Base.IndexStyle(V) +Base.size(x::MetadataOverlayVector) = size(x.data) +Base.axes(x::MetadataOverlayVector) = axes(x.data) +Base.length(x::MetadataOverlayVector) = length(x.data) +Base.getindex(x::MetadataOverlayVector, i::Int) = x.data[i] +Base.iterate(x::MetadataOverlayVector, state...) = iterate(x.data, state...) +getmetadata(x::MetadataOverlayVector) = x.metadata + +struct MetadataOverlayTable{N,C,M} + columns::NamedTuple{N,C} + metadata::M +end + +function Base.getproperty(x::MetadataOverlayTable, name::Symbol) + if name === :columns || name === :metadata + return getfield(x, name) + end + columns = getfield(x, :columns) + if hasproperty(columns, name) + return getproperty(columns, name) + end + return getfield(x, name) +end + +function Base.propertynames(x::MetadataOverlayTable, private::Bool=false) + column_names = propertynames(getfield(x, :columns)) + return private ? (:columns, :metadata, column_names...) : column_names +end + +Tables.istable(::Type{<:MetadataOverlayTable}) = true +Tables.columnaccess(::Type{<:MetadataOverlayTable}) = true +Tables.columns(x::MetadataOverlayTable) = getfield(x, :columns) +Tables.schema(x::MetadataOverlayTable) = Tables.schema(getfield(x, :columns)) +getmetadata(x::MetadataOverlayTable) = getfield(x, :metadata) + +function _column_metadata_overlay(table_like) + merged = Dict{Symbol,Any}() + for name in Tables.schema(table_like).names + metadata = + _normalize_metadata_overlay(getmetadata(Tables.getcolumn(table_like, name))) + metadata === nothing || (merged[name] = metadata) + end + return merged +end + +function _merge_column_metadata_overlays(table_like, colmetadata) + merged = _column_metadata_overlay(table_like) + colmetadata === nothing && return merged + for (name, metadata) in pairs(colmetadata) + symbol_name = Symbol(name) + merged_metadata = + _merge_metadata_overlays(get(merged, symbol_name, nothing), metadata) + merged_metadata === nothing || (merged[symbol_name] = merged_metadata) + end + return merged +end + +function _metadata_overlay_table(columns::NamedTuple; metadata=nothing, colmetadata=nothing) + wrapped_columns = Pair{Symbol,Any}[] + for name in keys(columns) + column_metadata = isnothing(colmetadata) ? nothing : get(colmetadata, name, nothing) + push!( + wrapped_columns, + name => MetadataOverlayVector(columns[name], column_metadata), + ) + end + return MetadataOverlayTable((; wrapped_columns...), metadata) +end + +""" + Arrow.withmetadata(table_like; metadata=nothing, colmetadata=nothing) + +Return a lightweight Tables.jl-compatible wrapper around `table_like` that +preserves any existing Arrow schema/field metadata and overlays additional +schema `metadata` and column `colmetadata` for subsequent Arrow serialization. + +Both `metadata` and `colmetadata` follow the same shape accepted by +[`Arrow.write`](@ref): schema metadata must be an iterable of string-like pairs, +while `colmetadata` must map column names to iterables of string-like pairs. +When the source already carries metadata, overlay entries win on key conflicts. +""" +function withmetadata(columns::NamedTuple; metadata=nothing, colmetadata=nothing) + normalized_metadata = _normalize_metadata_overlay(metadata) + normalized_colmetadata = if isnothing(colmetadata) + nothing + else + Dict( + Symbol(name) => _normalize_metadata_overlay(column_metadata) for + (name, column_metadata) in pairs(colmetadata) + ) + end + if normalized_metadata === nothing && isnothing(normalized_colmetadata) + return columns + end + return _metadata_overlay_table( + columns; + metadata=normalized_metadata, + colmetadata=normalized_colmetadata, + ) +end + +function withmetadata(table_like; metadata=nothing, colmetadata=nothing) + merged_metadata = _merge_metadata_overlays(getmetadata(table_like), metadata) + merged_colmetadata = _merge_column_metadata_overlays(table_like, colmetadata) + if merged_metadata === nothing && isempty(merged_colmetadata) + return table_like + end + return _metadata_overlay_table( + Tables.columntable(table_like); + metadata=merged_metadata, + colmetadata=isempty(merged_colmetadata) ? nothing : merged_colmetadata, + ) +end diff --git a/src/table.jl b/src/table.jl index de8bfc3..e2df728 100644 --- a/src/table.jl +++ b/src/table.jl @@ -28,6 +28,10 @@ tobytes(io::IO) = Base.read(io) tobytes(io::IOStream) = Mmap.mmap(io) tobytes(file_path) = open(tobytes, file_path, "r") +rejectunsupported(field::Meta.Field) = + (rejectunsupported(field.type); foreach(rejectunsupported, field.children)) +rejectunsupported(x) = nothing + struct BatchIterator bytes::Vector{UInt8} startpos::Int @@ -178,15 +182,19 @@ function Base.iterate(x::Stream, (pos, id)=(1, 0)) end batch, (pos, id) = state header = batch.msg.header - if isnothing(x.schema) && !isa(header, Meta.Schema) + if header isa Meta.Tensor + throw(ArgumentError(TENSOR_UNSUPPORTED)) + elseif header isa Meta.SparseTensor + throw(ArgumentError(SPARSE_TENSOR_UNSUPPORTED)) + elseif isnothing(x.schema) && !isa(header, Meta.Schema) throw(ArgumentError("first arrow ipc message MUST be a schema message")) - end - if header isa Meta.Schema + elseif header isa Meta.Schema if isnothing(x.schema) x.schema = header # assert endianness? # store custom_metadata? for (i, field) in enumerate(x.schema.fields) + rejectunsupported(field) push!(x.names, Symbol(field.name)) push!( x.types, @@ -264,6 +272,10 @@ function Base.iterate(x::Stream, (pos, id)=(1, 0)) push!(columns, vec) end break + elseif header isa Meta.Tensor + throw(ArgumentError(TENSOR_UNSUPPORTED)) + elseif header isa Meta.SparseTensor + throw(ArgumentError(SPARSE_TENSOR_UNSUPPORTED)) else throw(ArgumentError("unsupported arrow message type: $(typeof(header))")) end @@ -421,17 +433,70 @@ Tables.columnnames(t::Table) = names(t) Tables.getcolumn(t::Table, i::Int) = columns(t)[i] Tables.getcolumn(t::Table, nm::Symbol) = lookup(t)[nm] +struct MetadataVector{T,A<:AbstractVector{T},M} <: AbstractVector{T} + data::A + metadata::M +end + +Base.IndexStyle(::Type{<:MetadataVector}) = Base.IndexLinear() +Base.size(x::MetadataVector) = size(x.data) +Base.axes(x::MetadataVector) = axes(x.data) +Base.length(x::MetadataVector) = length(x.data) +Base.getindex(x::MetadataVector, i::Int) = getindex(x.data, i) +Base.iterate(x::MetadataVector) = iterate(x.data) +Base.iterate(x::MetadataVector, state) = iterate(x.data, state) +getmetadata(x::MetadataVector) = x.metadata + +_metadatavectordata(x::MetadataVector) = x.data +_metadatavectordata(x) = x +_wrapmetadata(data, metadata) = metadata === nothing ? data : MetadataVector(data, metadata) + struct TablePartitions table::Table npartitions::Int end +Base.IteratorSize(::Type{TablePartitions}) = Base.HasLength() +Base.length(tp::TablePartitions) = tp.npartitions + +function _partitionarrays(col::MetadataVector) + data = getfield(col, :data) + return data isa ChainedVector ? data.arrays : nothing +end + +_partitionarrays(col) = col isa ChainedVector ? col.arrays : _wrappedpartitionarrays(col) + +function _wrappedpartitionarrays(col) + if hasfield(typeof(col), :data) + data = getfield(col, :data) + data isa ChainedVector && return data.arrays + end + return nothing +end + +_partitioncolumn(col::MetadataVector, i::Int) = + MetadataVector(getfield(col, :data).arrays[i], getfield(col, :metadata)) + +_partitioncolumn(col, i::Int) = + col isa ChainedVector ? col.arrays[i] : _wrappedpartitioncolumn(col, i) + +function _wrappedpartitioncolumn(col, i::Int) + if hasfield(typeof(col), :data) && hasfield(typeof(col), :metadata) + data = getfield(col, :data) + if data isa ChainedVector + wrapper = getfield(parentmodule(typeof(col)), nameof(typeof(col))) + return wrapper(data.arrays[i], getfield(col, :metadata)) + end + end + return col +end + function TablePartitions(table::Table) cols = columns(table) npartitions = if length(cols) == 0 0 - elseif cols[1] isa ChainedVector - length(cols[1].arrays) + elseif (arrays = _partitionarrays(cols[1])) !== nothing + length(arrays) else 1 end @@ -442,7 +507,7 @@ function Base.iterate(tp::TablePartitions, i=1) i > tp.npartitions && return nothing tp.npartitions == 1 && return tp.table, i + 1 cols = columns(tp.table) - newcols = AbstractVector[cols[j].arrays[i] for j = 1:length(cols)] + newcols = AbstractVector[_partitioncolumn(cols[j], i) for j = 1:length(cols)] nms = names(tp.table) tbl = Table( nms, @@ -487,6 +552,7 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true) # store custom_metadata? if sch === nothing for (i, field) in enumerate(header.fields) + rejectunsupported(field) push!(names(t), Symbol(field.name)) # recursively find any dictionaries for any fields getdictionaries!(dictencoded, field) @@ -578,6 +644,10 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true) ), ) rbi += 1 + elseif header isa Meta.Tensor + throw(ArgumentError(TENSOR_UNSUPPORTED)) + elseif header isa Meta.SparseTensor + throw(ArgumentError(SPARSE_TENSOR_UNSUPPORTED)) else throw(ArgumentError("unsupported arrow message type: $(typeof(header))")) end @@ -732,6 +802,18 @@ const ListTypes = const LargeLists = Union{Meta.LargeUtf8,Meta.LargeBinary,Meta.LargeList,Meta.LargeListView} const ViewTypes = Union{Meta.Utf8View,Meta.BinaryView,Meta.ListView,Meta.LargeListView} +@inline function _viewbuffercount(validity, views, declared::Integer) + count = Int(declared) + for i in eachindex(views) + validity[i] || continue + v = @inbounds views[i] + if !_viewisinline(v.length) + count = max(count, Int(v.bufindex) + 1) + end + end + return count +end + function build(field::Meta.Field, batch, rb, de, nodeidx, bufferidx, varbufferidx, convert) d = field.dictionary if d !== nothing @@ -891,6 +973,32 @@ function build( varbufferidx end +function build( + f::Meta.Field, + x::Meta.RunEndEncoded, + batch, + rb, + de, + nodeidx, + bufferidx, + varbufferidx, + convert, +) + @debug "building array: x = $x" + len = rb.nodes[nodeidx].length + nodeidx += 1 + meta = buildmetadata(f.custom_metadata) + T = juliaeltype(f, meta, convert) + run_ends, nodeidx, bufferidx, varbufferidx = + build(f.children[1], batch, rb, de, nodeidx, bufferidx, varbufferidx, false) + values, nodeidx, bufferidx, varbufferidx = + build(f.children[2], batch, rb, de, nodeidx, bufferidx, varbufferidx, convert) + return _makerunendencoded(T, run_ends, values, len, meta), + nodeidx, + bufferidx, + varbufferidx +end + function build( f::Meta.Field, L::ViewTypes, @@ -910,7 +1018,8 @@ function build( inline = reinterpret(UInt8, views) # reuse the (possibly realigned) memory backing `views` bufferidx += 1 buffers = Vector{UInt8}[] - for i = 1:rb.variadicBufferCounts[varbufferidx] + nvariadic = _viewbuffercount(validity, views, rb.variadicBufferCounts[varbufferidx]) + for i = 1:nvariadic buffer = rb.buffers[bufferidx] _, A = reinterp(UInt8, batch, buffer, rb.compression) push!(buffers, A) diff --git a/src/utils.jl b/src/utils.jl index 8e2dfee..05a297c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -37,6 +37,33 @@ end # efficient writing of arrays writearray(io, col) = writearray(io, maybemissing(eltype(col)), col) +function _writearrayfallback(io::IO, ::Type{T}, col) where {T} + n = 0 + data = Vector{UInt8}(undef, sizeof(col)) + buf = IOBuffer(data; write=true) + for x in col + n += Base.write(buf, coalesce(x, ArrowTypes.default(T))) + end + n = Base.write(io, take!(buf)) + return n +end + +@inline function _writearraycontiguous(io::IO, ::Type{T}, data) where {T} + return Base.unsafe_write(io, pointer(data), sizeof(T) * length(data)) +end + +@inline function _contiguoustoarrowdata(::Type{T}, col::ArrowTypes.ToArrow) where {T} + ArrowTypes._needsconvert(col) && return nothing + data = ArrowTypes._sourcedata(col) + strides(data) == (1,) || return nothing + if data isa AbstractVector{T} + return isbitstype(T) ? data : nothing + elseif isbitstype(T) && data isa AbstractVector{Union{T,Missing}} + return data + end + return nothing +end + function writearray(io::IO, ::Type{T}, col) where {T} if col isa Vector{T} n = Base.write(io, col) @@ -51,17 +78,17 @@ function writearray(io::IO, ::Type{T}, col) where {T} n += writearray(io, T, A) end else - n = 0 - data = Vector{UInt8}(undef, sizeof(col)) - buf = IOBuffer(data; write=true) - for x in col - n += Base.write(buf, coalesce(x, ArrowTypes.default(T))) - end - n = Base.write(io, take!(buf)) + n = _writearrayfallback(io, T, col) end return n end +function writearray(io::IO, ::Type{T}, col::ArrowTypes.ToArrow) where {T} + data = _contiguoustoarrowdata(T, col) + isnothing(data) || return _writearraycontiguous(io, T, data) + return _writearrayfallback(io, T, col) +end + getbit(v::UInt8, n::Integer) = (v & (1 << (n - 1))) > 0x00 function setbit(v::UInt8, b::Bool, n::Integer) @@ -120,6 +147,11 @@ function getrb(filebytes) # FlatBuffers.getrootas(Meta.Message, filebytes, rb.offset) end +@inline function messagebytes(msg, alignment) + metalen = padding(length(msg.msgflatbuf), alignment) + return 8 + metalen + msg.bodylen +end + function readmessage(filebytes, off=9) @assert readbuffer(filebytes, off, UInt32) === 0xFFFFFFFF len = readbuffer(filebytes, off + 4, Int32) @@ -127,7 +159,251 @@ function readmessage(filebytes, off=9) FlatBuffers.getrootas(Meta.Message, filebytes, off + 8) end +@inline _issinglepartition(parts) = parts isa Tuple && length(parts) == 1 + +@inline function _directtobuffercoleligible(col) + T = Base.nonmissingtype(eltype(col)) + T <: AbstractString && return false + T <: Base.CodeUnits && return false + K = ArrowTypes.ArrowKind(ArrowTypes.ArrowType(T)) + return !(K isa ArrowTypes.ListKind) +end + +@inline function _directtobufferstringonly(col) + T = Base.nonmissingtype(eltype(col)) + return T <: AbstractString +end + +@inline function _directtobufferbinaryonly(col) + return eltype(col) <: Base.CodeUnits +end + +@inline function _directstreamcoleligible(col) + return !(col isa DictEncode) && + DataAPI.refarray(col) === col && + (_directtobufferstringonly(col) || _directtobufferbinaryonly(col)) +end + +function _directtobuffereligible(part) + tblcols = Tables.columns(part) + sch = Tables.schema(tblcols) + ncols = 0 + singlecolspecial = false + allnonstrings = true + Tables.eachcolumn(sch, tblcols) do col, _, _ + ncols += 1 + eligible = _directtobuffercoleligible(col) + allnonstrings &= eligible + singlecolspecial = + ncols == 1 && (_directtobufferstringonly(col) || _directtobufferbinaryonly(col)) + end + return allnonstrings || (ncols == 1 && singlecolspecial) +end + +@inline function _directstreameligible(part) + tblcols = Tables.columns(part) + sch = Tables.schema(tblcols) + ncols = 0 + singlecolspecial = false + Tables.eachcolumn(sch, tblcols) do col, _, _ + ncols += 1 + singlecolspecial = ncols == 1 && _directstreamcoleligible(col) + end + return ncols == 1 && singlecolspecial +end + +@inline _partitionsinspectable(parts) = + parts isa Tuple || parts isa AbstractVector || parts isa Tables.Partitioner + +@inline function _directtobuffersizehint( + cols, + dictmsgs, + schmsg, + recbatchmsg, + endmsg, + alignment, +) + for col in Tables.Columns(cols) + if col isa Map + return + messagebytes(schmsg, alignment) + + sum(msg -> messagebytes(msg, alignment), dictmsgs; init=0) + + messagebytes(recbatchmsg, alignment) + + messagebytes(endmsg, alignment) + end + end + return nothing +end + +function _writedictionarymessages!(io, blocks, schref, alignment, dictencodings) + isempty(dictencodings) && return + des = sort!(collect(dictencodings); by=x -> x.first, rev=true) + for (id, delock) in des + de = delock.value + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + msg = makedictionarybatchmsg(dictsch, (col=de.data,), id, false, alignment) + Base.write(io, msg, blocks, schref, alignment) + end + return +end + +function _writedictionarydeltas!(io, blocks, schref, alignment, deltas) + isempty(deltas) && return + for de in deltas + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + msg = makedictionarybatchmsg(dictsch, (col=de.data,), de.id, true, alignment) + Base.write(io, msg, blocks, schref, alignment) + end + return +end + +@inline function _directstreamstate(parts) + _partitionsinspectable(parts) || return nothing + firststate = iterate(parts) + isnothing(firststate) && return nothing + firstpart, state = firststate + isnothing(iterate(parts, state)) && return nothing + return firstpart, state +end + +function _directtobuffer(part, source, kwargs) + largelists = get(kwargs, :largelists, false) + compress = get(kwargs, :compress, nothing) + denseunions = get(kwargs, :denseunions, true) + dictencode = get(kwargs, :dictencode, false) + dictencodenested = get(kwargs, :dictencodenested, false) + alignment = Int32(get(kwargs, :alignment, 8)) + maxdepth = get(kwargs, :maxdepth, DEFAULT_MAX_DEPTH) + metadata = get(kwargs, :metadata, getmetadata(source)) + colmetadata = get(kwargs, :colmetadata, nothing) + + tblcols = Tables.columns(part) + dictencodings = Dict{Int64,Any}() + cols = toarrowtable( + tblcols, + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + metadata, + colmetadata, + ) + sch = Tables.schema(cols) + schmsg = makeschemamsg(sch, cols) + dictmsgs = if isempty(dictencodings) + Message[] + else + des = sort!(collect(dictencodings); by=x -> x.first, rev=true) + [ + begin + de = delock.value + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + makedictionarybatchmsg(dictsch, (col=de.data,), id, false, alignment) + end for (id, delock) in des + ] + end + recbatchmsg = makerecordbatchmsg(sch, cols, alignment) + endmsg = Message(UInt8[], nothing, 0, true, false, Meta.Schema) + sizehint = + _directtobuffersizehint(cols, dictmsgs, schmsg, recbatchmsg, endmsg, alignment) + io = isnothing(sizehint) ? IOBuffer() : IOBuffer(; sizehint=sizehint) + blocks = (Block[], Block[]) + schref = Ref(sch) + Base.write(io, schmsg, blocks, schref, alignment) + foreach(msg -> Base.write(io, msg, blocks, schref, alignment), dictmsgs) + Base.write(io, recbatchmsg, blocks, schref, alignment) + Base.write(io, endmsg, blocks, schref, alignment) + seekstart(io) + return io +end + +function _directstreamwrite!(io::IO, firstpart, state, parts, source, kwargs) + largelists = get(kwargs, :largelists, false) + compress = get(kwargs, :compress, nothing) + denseunions = get(kwargs, :denseunions, true) + dictencode = get(kwargs, :dictencode, false) + dictencodenested = get(kwargs, :dictencodenested, false) + alignment = Int32(get(kwargs, :alignment, 8)) + maxdepth = get(kwargs, :maxdepth, DEFAULT_MAX_DEPTH) + metadata = get(kwargs, :metadata, getmetadata(source)) + colmetadata = get(kwargs, :colmetadata, nothing) + + dictencodings = Dict{Int64,Any}() + firstcols = toarrowtable( + Tables.columns(firstpart), + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + metadata, + colmetadata, + ) + sch = Tables.schema(firstcols) + schmsg = makeschemamsg(sch, firstcols) + blocks = (Block[], Block[]) + schref = Ref(sch) + Base.write(io, schmsg, blocks, schref, alignment) + _writedictionarymessages!(io, blocks, schref, alignment, dictencodings) + Base.write(io, makerecordbatchmsg(sch, firstcols, alignment), blocks, schref, alignment) + + next = iterate(parts, state) + while !isnothing(next) + part, state = next + cols = toarrowtable( + Tables.columns(part), + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + metadata, + colmetadata, + ) + Tables.schema(cols) == sch || + throw(ArgumentError("all partitions must have the exact same Tables.Schema")) + _writedictionarydeltas!(io, blocks, schref, alignment, cols.dictencodingdeltas) + Base.write(io, makerecordbatchmsg(sch, cols, alignment), blocks, schref, alignment) + next = iterate(parts, state) + end + Base.write( + io, + Message(UInt8[], nothing, 0, true, false, Meta.Schema), + blocks, + schref, + alignment, + ) + return io +end + +function _directstreamtobuffer(firstpart, state, parts, source, kwargs) + io = IOBuffer() + _directstreamwrite!(io, firstpart, state, parts, source, kwargs) + seekstart(io) + return io +end + function tobuffer(data; kwargs...) + parts = Tables.partitions(data) + if !get(kwargs, :file, false) + if _issinglepartition(parts) && _directtobuffereligible(parts[1]) + return _directtobuffer(parts[1], data, kwargs) + else + streamstate = _directstreamstate(parts) + if !isnothing(streamstate) + firstpart, state = streamstate + _directstreameligible(firstpart) && + return _directstreamtobuffer(firstpart, state, parts, data, kwargs) + end + end + end io = IOBuffer() write(io, data; kwargs...) seekstart(io) diff --git a/src/write.jl b/src/write.jl index 4c3800f..a25d4af 100644 --- a/src/write.jl +++ b/src/write.jl @@ -403,6 +403,17 @@ function Base.close(writer::Writer) end function write(io::IO, tbl; kwargs...) + if !get(kwargs, :file, false) + parts = Tables.partitions(tbl) + streamstate = _directstreamstate(parts) + if !isnothing(streamstate) + firstpart, state = streamstate + if _directstreameligible(firstpart) + _directstreamwrite!(io, firstpart, state, parts, tbl, kwargs) + return io + end + end + end open(Writer, io; file=false, kwargs...) do writer write(writer, tbl) end diff --git a/test/Project.toml b/test/Project.toml index c2e02aa..f5e62b1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,15 +6,17 @@ # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. [deps] +Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" @@ -25,15 +27,17 @@ JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" +TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +gRPCClient = "aaca4a50-36af-4a1d-b878-4c443f2061ad" [compat] ArrowTypes = "2.3" @@ -44,8 +48,10 @@ FilePathsBase = "0.9" JSON3 = "1" OffsetArrays = "1" PooledArrays = "1" -StructTypes = "1" +ProtoBuf = "~1.2.1" SentinelArrays = "1" +StructTypes = "1" Tables = "1" TestSetExtensions = "3" TimeZones = "1" +gRPCClient = "1" diff --git a/test/flight.jl b/test/flight.jl new file mode 100644 index 0000000..fdd8f22 --- /dev/null +++ b/test/flight.jl @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +using gRPCClient +using Tables + +include("flight/support.jl") +include("flight/header_interop.jl") +include("flight/handshake_interop.jl") +include("flight/tls_interop.jl") +include("flight/poll_interop.jl") +include("flight/client_surface.jl") +include("flight/server_core.jl") +include("flight/grpcserver_extension.jl") +include("flight/ipc_conversion.jl") +include("flight/ipc_schema_separation.jl") +include("flight/pyarrow_interop.jl") diff --git a/test/flight/client_surface.jl b/test/flight/client_surface.jl new file mode 100644 index 0000000..e50bdd9 --- /dev/null +++ b/test/flight/client_surface.jl @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("client_surface/support.jl") +include("client_surface/constructor_tests.jl") +include("client_surface/header_tls_tests.jl") +include("client_surface/protocol_client_tests.jl") + +@testset "Flight RPC client surface" begin + fixture = flight_client_surface_fixture() + flight_client_surface_test_constructors(fixture) + flight_client_surface_test_header_tls_helpers(fixture) + flight_client_surface_test_protocol_clients(fixture) +end diff --git a/test/flight/client_surface/constructor_tests.jl b/test/flight/client_surface/constructor_tests.jl new file mode 100644 index 0000000..ca47f02 --- /dev/null +++ b/test/flight/client_surface/constructor_tests.jl @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_client_surface_test_constructors(fixture) + client = fixture.client + + @test client.host == "localhost" + @test client.port == 8815 + @test client.secure + @test client.deadline == 30.0 + @test client.keepalive == 15.0 + @test client.max_send_message_length == 1024 + @test client.max_recieve_message_length == 2048 + @test isempty(client.headers) + @test isnothing(client.tls_root_certs) + @test isnothing(client.cert_chain) + @test isnothing(client.private_key) + @test isnothing(client.key_password) + @test !client.disable_server_verification + + uri_client = Arrow.Flight.Client("grpc://127.0.0.1:31337") + @test uri_client.host == "127.0.0.1" + @test uri_client.port == 31337 + @test !uri_client.secure + + tls_client = Arrow.Flight.Client("grpc+tls://example.com:9443") + @test tls_client.host == "example.com" + @test tls_client.port == 9443 + @test tls_client.secure + + location_client = + Arrow.Flight.Client(fixture.protocol.Location("https://demo.example:8443")) + @test location_client.host == "demo.example" + @test location_client.port == 8443 + @test location_client.secure + + @test_throws ArgumentError Arrow.Flight.Client("grpc://missing-port") +end diff --git a/test/flight/client_surface/header_tls_tests.jl b/test/flight/client_surface/header_tls_tests.jl new file mode 100644 index 0000000..326684d --- /dev/null +++ b/test/flight/client_surface/header_tls_tests.jl @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_client_surface_test_header_tls_helpers(fixture) + client = fixture.client + + tls_client = Arrow.Flight.Client( + "grpc+tls://secure.example:9443"; + tls_root_certs="/tmp/root.pem", + cert_chain="/tmp/client.pem", + private_key="/tmp/client.key", + key_password="secret", + disable_server_verification=true, + ) + @test tls_client.tls_root_certs == "/tmp/root.pem" + @test tls_client.cert_chain == "/tmp/client.pem" + @test tls_client.private_key == "/tmp/client.key" + @test tls_client.key_password == "secret" + @test tls_client.disable_server_verification + + header_client = Arrow.Flight.withheaders( + client, + "authorization" => "Bearer token1234", + "x-trace-id" => "trace-1", + ) + @test header_client.headers == + ["authorization" => "Bearer token1234", "x-trace-id" => "trace-1"] + @test header_client.host == client.host + @test header_client.grpc === client.grpc + @test header_client.disable_server_verification == client.disable_server_verification + + binary_header_client = + Arrow.Flight.withheaders(client, "auth-token-bin" => UInt8[0x00, 0xff, 0x41]) + @test binary_header_client.headers == ["auth-token-bin" => UInt8[0x00, 0xff, 0x41]] + @test Arrow.Flight._header_lines(binary_header_client.headers) == + ["auth-token-bin: AP9B"] + + token_client = Arrow.Flight.withtoken(client, UInt8[0x01, 0x02]) + @test token_client.headers == ["auth-token-bin" => UInt8[0x01, 0x02]] + @test Arrow.Flight._header_lines(token_client.headers) == ["auth-token-bin: AQI="] + + invalid_binary_header_client = + Arrow.Flight.withheaders(client, "x-binary" => UInt8[0x00]) + @test_throws ArgumentError Arrow.Flight._header_lines( + invalid_binary_header_client.headers, + ) +end diff --git a/test/flight/client_surface/protocol_client_tests.jl b/test/flight/client_surface/protocol_client_tests.jl new file mode 100644 index 0000000..eae3b8f --- /dev/null +++ b/test/flight/client_surface/protocol_client_tests.jl @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_client_surface_test_protocol_clients(fixture) + client = fixture.client + + @test isdefined(fixture.protocol, :FlightService_Handshake_Client) + @test isdefined(fixture.protocol, :FlightService_ListFlights_Client) + @test isdefined(fixture.protocol, :FlightService_GetFlightInfo_Client) + @test isdefined(fixture.protocol, :FlightService_PollFlightInfo_Client) + @test isdefined(fixture.protocol, :FlightService_GetSchema_Client) + @test isdefined(fixture.protocol, :FlightService_DoGet_Client) + @test isdefined(fixture.protocol, :FlightService_DoPut_Client) + @test isdefined(fixture.protocol, :FlightService_DoExchange_Client) + @test isdefined(fixture.protocol, :FlightService_DoAction_Client) + @test isdefined(fixture.protocol, :FlightService_ListActions_Client) + + @test Arrow.Flight._handshake_client(client).path == + "/arrow.flight.protocol.FlightService/Handshake" + @test Arrow.Flight._listflights_client(client).path == + "/arrow.flight.protocol.FlightService/ListFlights" + @test Arrow.Flight._getflightinfo_client(client).path == + "/arrow.flight.protocol.FlightService/GetFlightInfo" + @test Arrow.Flight._pollflightinfo_client(client).path == + "/arrow.flight.protocol.FlightService/PollFlightInfo" + @test Arrow.Flight._getschema_client(client).path == + "/arrow.flight.protocol.FlightService/GetSchema" + @test Arrow.Flight._doget_client(client).path == + "/arrow.flight.protocol.FlightService/DoGet" + @test Arrow.Flight._doput_client(client).path == + "/arrow.flight.protocol.FlightService/DoPut" + @test Arrow.Flight._doexchange_client(client).path == + "/arrow.flight.protocol.FlightService/DoExchange" + @test Arrow.Flight._doaction_client(client).path == + "/arrow.flight.protocol.FlightService/DoAction" + @test Arrow.Flight._listactions_client(client).path == + "/arrow.flight.protocol.FlightService/ListActions" +end diff --git a/test/flight/client_surface/support.jl b/test/flight/client_surface/support.jl new file mode 100644 index 0000000..eaa065d --- /dev/null +++ b/test/flight/client_surface/support.jl @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_client_surface_fixture() + client = Arrow.Flight.Client( + "localhost", + 8815; + secure=true, + deadline=30, + keepalive=15, + max_send_message_length=1024, + max_recieve_message_length=2048, + ) + return (; client, protocol=Arrow.Flight.Protocol) +end diff --git a/test/flight/grpcserver_extension.jl b/test/flight/grpcserver_extension.jl new file mode 100644 index 0000000..55c7a64 --- /dev/null +++ b/test/flight/grpcserver_extension.jl @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("grpcserver_extension/support.jl") +include("grpcserver_extension/descriptor_tests.jl") +include("grpcserver_extension/unary_tests.jl") +include("grpcserver_extension/streaming_tests.jl") + +@testset "Flight gRPCServer extension" begin + grpcserver = FlightTestSupport.load_grpcserver() + if isnothing(grpcserver) + @test true + else + protocol = Arrow.Flight.Protocol + fixture = grpcserver_extension_fixture(protocol) + service = grpcserver_extension_service(protocol, fixture) + metadata = grpcserver_extension_metadata() + + grpcserver_extension_test_descriptor(grpcserver, service) + grpcserver_extension_test_unary(grpcserver, service, fixture, metadata) + grpcserver_extension_test_streaming(grpcserver, service, fixture, metadata) + end +end diff --git a/test/flight/grpcserver_extension/bidi_streaming_tests.jl b/test/flight/grpcserver_extension/bidi_streaming_tests.jl new file mode 100644 index 0000000..e2f1aff --- /dev/null +++ b/test/flight/grpcserver_extension/bidi_streaming_tests.jl @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_test_bidi_streaming(grpcserver, service, fixture, metadata) + grpc_descriptor = grpcserver.service_descriptor(service) + protocol = Arrow.Flight.Protocol + + handshake_messages, handshake_closed, handshake_stream = grpcserver_capture_bidi_stream( + grpcserver, + protocol.HandshakeRequest, + protocol.HandshakeResponse, + fixture.handshake_requests, + ) + grpc_descriptor.methods["Handshake"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/Handshake"; + metadata=metadata, + ), + handshake_stream, + ) + @test handshake_closed[] + @test length(handshake_messages) == 1 + @test handshake_messages[1].payload == b"native-token" + + doput_messages, doput_closed, doput_stream = grpcserver_capture_bidi_stream( + grpcserver, + protocol.FlightData, + protocol.PutResult, + fixture.messages, + ) + grpc_descriptor.methods["DoPut"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoPut"; + metadata=metadata, + ), + doput_stream, + ) + @test doput_closed[] + @test length(doput_messages) == 1 + @test String(doput_messages[1].app_metadata) == "stored" + + doexchange_messages, doexchange_closed, doexchange_stream = + grpcserver_capture_bidi_stream( + grpcserver, + protocol.FlightData, + protocol.FlightData, + fixture.exchange_messages, + ) + grpc_descriptor.methods["DoExchange"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoExchange"; + metadata=metadata, + ), + doexchange_stream, + ) + @test doexchange_closed[] + @test length(doexchange_messages) == length(fixture.exchange_messages) + doexchange_table = Arrow.Flight.table(doexchange_messages) + @test doexchange_table.id == [10] + @test doexchange_table.name == ["ten"] + @test Arrow.getmetadata(doexchange_table)["dataset"] == "exchange" + @test Arrow.getmetadata(doexchange_table.name)["lang"] == "exchange" + + failing_service = Arrow.Flight.Service( + doexchange=(ctx, request, response) -> + throw(ArgumentError("bidi streaming failed before first response")), + ) + failing_descriptor = grpcserver.service_descriptor(failing_service) + failing_messages, failing_closed, failing_stream = grpcserver_capture_bidi_stream( + grpcserver, + protocol.FlightData, + protocol.FlightData, + fixture.exchange_messages, + ) + failure = try + failing_descriptor.methods["DoExchange"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoExchange"; + metadata=metadata, + ), + failing_stream, + ) + nothing + catch err + err + end + @test failure isa ArgumentError + @test occursin( + "bidi streaming failed before first response", + sprint(showerror, failure), + ) + @test !failing_closed[] + @test isempty(failing_messages) +end diff --git a/test/flight/grpcserver_extension/descriptor_tests.jl b/test/flight/grpcserver_extension/descriptor_tests.jl new file mode 100644 index 0000000..631891b --- /dev/null +++ b/test/flight/grpcserver_extension/descriptor_tests.jl @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_test_descriptor(grpcserver, service) + grpc_descriptor = grpcserver.service_descriptor(service) + @test Base.get_extension(Arrow, :ArrowgRPCServerExt) !== nothing + @test grpc_descriptor.name == "arrow.flight.protocol.FlightService" + @test haskey(grpc_descriptor.methods, "GetFlightInfo") + @test haskey(grpc_descriptor.methods, "DoGet") + @test haskey(grpc_descriptor.methods, "DoExchange") + @test grpc_descriptor.methods["GetFlightInfo"].method_type == + grpcserver.MethodType.UNARY + @test grpc_descriptor.methods["DoGet"].method_type == + grpcserver.MethodType.SERVER_STREAMING + @test grpc_descriptor.methods["DoExchange"].method_type == + grpcserver.MethodType.BIDI_STREAMING + @test grpc_descriptor.methods["DoGet"].input_type == "arrow.flight.protocol.Ticket" + @test grpc_descriptor.methods["DoGet"].output_type == "arrow.flight.protocol.FlightData" + return grpc_descriptor +end diff --git a/test/flight/grpcserver_extension/server_streaming_tests.jl b/test/flight/grpcserver_extension/server_streaming_tests.jl new file mode 100644 index 0000000..7eb8b45 --- /dev/null +++ b/test/flight/grpcserver_extension/server_streaming_tests.jl @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_test_server_streaming(grpcserver, service, fixture, metadata) + grpc_descriptor = grpcserver.service_descriptor(service) + protocol = Arrow.Flight.Protocol + + doget_messages, doget_closed, doget_stream = + grpcserver_capture_server_stream(grpcserver, protocol.FlightData) + grpc_descriptor.methods["DoGet"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoGet"; + metadata=metadata, + ), + fixture.ticket, + doget_stream, + ) + @test doget_closed[] + @test length(doget_messages) == length(fixture.messages) + doget_table = Arrow.Flight.table(doget_messages; schema=fixture.info) + @test doget_table.name == ["one", "two", "three"] + @test Arrow.getmetadata(doget_table)["dataset"] == "native" + @test Arrow.getmetadata(doget_table.name)["lang"] == "en" + + doget_any_messages = Any[] + doget_any_closed = Ref(false) + doget_any_stream = grpcserver.ServerStream{Any}( + (message, compress) -> begin + @test compress + push!(doget_any_messages, message) + end, + () -> (doget_any_closed[] = true), + ) + grpc_descriptor.methods["DoGet"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoGet"; + metadata=metadata, + ), + fixture.ticket, + doget_any_stream, + ) + @test doget_any_closed[] + @test length(doget_any_messages) == length(fixture.messages) + @test all(message -> message isa protocol.FlightData, doget_any_messages) + + actions_messages, actions_closed, actions_stream = + grpcserver_capture_server_stream(grpcserver, protocol.ActionType) + grpc_descriptor.methods["ListActions"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/ListActions"; + metadata=metadata, + ), + protocol.Empty(), + actions_stream, + ) + @test actions_closed[] + @test length(actions_messages) == 1 + @test actions_messages[1].var"#type" == "ping" + + action_messages, action_closed, action_stream = + grpcserver_capture_server_stream(grpcserver, protocol.Result) + grpc_descriptor.methods["DoAction"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoAction"; + metadata=metadata, + ), + protocol.Action("ping", UInt8[]), + action_stream, + ) + @test action_closed[] + @test length(action_messages) == 1 + @test String(action_messages[1].body) == "pong" + + failing_service = Arrow.Flight.Service( + doget=(ctx, req, response) -> + throw(ArgumentError("server streaming failed before first response")), + ) + failing_descriptor = grpcserver.service_descriptor(failing_service) + failing_messages, failing_closed, failing_stream = + grpcserver_capture_server_stream(grpcserver, protocol.FlightData) + failure = try + failing_descriptor.methods["DoGet"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoGet"; + metadata=metadata, + ), + fixture.ticket, + failing_stream, + ) + nothing + catch err + err + end + @test failure isa ArgumentError + @test occursin( + "server streaming failed before first response", + sprint(showerror, failure), + ) + @test !failing_closed[] + @test isempty(failing_messages) +end diff --git a/test/flight/grpcserver_extension/streaming_tests.jl b/test/flight/grpcserver_extension/streaming_tests.jl new file mode 100644 index 0000000..d82ba90 --- /dev/null +++ b/test/flight/grpcserver_extension/streaming_tests.jl @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("server_streaming_tests.jl") +include("bidi_streaming_tests.jl") + +function grpcserver_extension_test_streaming(grpcserver, service, fixture, metadata) + grpcserver_extension_test_server_streaming(grpcserver, service, fixture, metadata) + grpcserver_extension_test_bidi_streaming(grpcserver, service, fixture, metadata) +end diff --git a/test/flight/grpcserver_extension/support.jl b/test/flight/grpcserver_extension/support.jl new file mode 100644 index 0000000..7daf8a4 --- /dev/null +++ b/test/flight/grpcserver_extension/support.jl @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("support/fixture.jl") +include("support/service.jl") +include("support/context.jl") +include("support/streams.jl") diff --git a/test/flight/grpcserver_extension/support/context.jl b/test/flight/grpcserver_extension/support/context.jl new file mode 100644 index 0000000..2d46598 --- /dev/null +++ b/test/flight/grpcserver_extension/support/context.jl @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +grpcserver_extension_metadata() = + Dict{String,Union{String,Vector{UInt8}}}("authorization" => "Bearer native") + +function grpcserver_extension_context( + grpcserver, + method::AbstractString; + metadata=grpcserver_extension_metadata(), +) + return grpcserver.ServerContext(method=String(method), metadata=metadata) +end diff --git a/test/flight/grpcserver_extension/support/fixture.jl b/test/flight/grpcserver_extension/support/fixture.jl new file mode 100644 index 0000000..0bd0066 --- /dev/null +++ b/test/flight/grpcserver_extension/support/fixture.jl @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_fixture(protocol) + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + descriptor = + protocol.FlightDescriptor(descriptor_type.PATH, UInt8[], ["native", "dataset"]) + ticket = protocol.Ticket(b"native-ticket") + dataset_metadata = Dict("dataset" => "native") + dataset_colmetadata = Dict(:name => Dict("lang" => "en")) + messages = Arrow.Flight.flightdata( + Tables.partitioner(( + (id=Int64[1, 2], name=["one", "two"]), + (id=Int64[3], name=["three"]), + )); + descriptor=descriptor, + metadata=dataset_metadata, + colmetadata=dataset_colmetadata, + ) + schema_bytes = Arrow.Flight.schemaipc(first(messages)) + info = protocol.FlightInfo( + schema_bytes[5:end], + descriptor, + [protocol.FlightEndpoint(ticket, protocol.Location[], nothing, UInt8[])], + Int64(3), + Int64(-1), + false, + UInt8[], + ) + handshake_requests = [protocol.HandshakeRequest(UInt64(0), b"native-token")] + exchange_metadata = Dict("dataset" => "exchange") + exchange_colmetadata = Dict(:name => Dict("lang" => "exchange")) + exchange_messages = Arrow.Flight.flightdata( + Tables.partitioner(((id=Int64[10], name=["ten"]),)); + descriptor=descriptor, + metadata=exchange_metadata, + colmetadata=exchange_colmetadata, + ) + return ( + descriptor=descriptor, + ticket=ticket, + messages=messages, + schema_bytes=schema_bytes, + info=info, + handshake_requests=handshake_requests, + dataset_metadata=dataset_metadata, + dataset_colmetadata=dataset_colmetadata, + exchange_messages=exchange_messages, + exchange_metadata=exchange_metadata, + exchange_colmetadata=exchange_colmetadata, + ) +end diff --git a/test/flight/grpcserver_extension/support/service.jl b/test/flight/grpcserver_extension/support/service.jl new file mode 100644 index 0000000..fea4cc5 --- /dev/null +++ b/test/flight/grpcserver_extension/support/service.jl @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_service(protocol, fixture) + return Arrow.Flight.Service( + handshake=(ctx, request, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + incoming = collect(request) + @test length(incoming) == 1 + put!(response, protocol.HandshakeResponse(UInt64(0), incoming[1].payload)) + close(response) + return :handshake_ok + end, + getflightinfo=(ctx, req) -> begin + @test req.path == fixture.descriptor.path + return fixture.info + end, + getschema=(ctx, req) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + @test req.path == fixture.descriptor.path + return protocol.SchemaResult(fixture.schema_bytes[5:end]) + end, + doget=(ctx, req, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + @test req.ticket == fixture.ticket.ticket + Arrow.Flight.putflightdata!( + response, + Tables.partitioner(( + (id=Int64[1, 2], name=["one", "two"]), + (id=Int64[3], name=["three"]), + )); + descriptor=fixture.descriptor, + metadata=fixture.dataset_metadata, + colmetadata=fixture.dataset_colmetadata, + close=true, + ) + return :doget_ok + end, + listactions=(ctx, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + put!(response, protocol.ActionType("ping", "Ping action")) + close(response) + return :listactions_ok + end, + doaction=(ctx, action, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + @test action.var"#type" == "ping" + put!(response, protocol.Result(b"pong")) + close(response) + return :doaction_ok + end, + doput=(ctx, request, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + incoming = collect(Arrow.Flight.stream(request)) + @test length(incoming) == 2 + @test incoming[1].id == [1, 2] + @test incoming[1].name == ["one", "two"] + @test Arrow.getmetadata(incoming[1])["dataset"] == "native" + @test Arrow.getmetadata(incoming[1].name)["lang"] == "en" + @test incoming[2].id == [3] + @test incoming[2].name == ["three"] + put!(response, protocol.PutResult(b"stored")) + close(response) + return :doput_ok + end, + doexchange=(ctx, request, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + Arrow.Flight.putflightdata!(response, Arrow.Flight.stream(request); close=true) + return :doexchange_ok + end, + ) +end diff --git a/test/flight/grpcserver_extension/support/streams.jl b/test/flight/grpcserver_extension/support/streams.jl new file mode 100644 index 0000000..7d9da0a --- /dev/null +++ b/test/flight/grpcserver_extension/support/streams.jl @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_capture_server_stream(grpcserver, ::Type{T}) where {T} + messages = T[] + closed = Ref(false) + stream = grpcserver.ServerStream{T}( + (message, compress) -> begin + @test compress + push!(messages, message) + end, + () -> (closed[] = true), + ) + return messages, closed, stream +end + +function grpcserver_capture_bidi_stream( + grpcserver, + ::Type{Request}, + ::Type{Response}, + requests, +) where {Request,Response} + messages = Response[] + closed = Ref(false) + stream = grpcserver.BidiStream{Request,Response}( + FlightTestSupport.next_message_factory(requests), + (message, compress) -> begin + @test compress + push!(messages, message) + end, + () -> (closed[] = true), + () -> false, + ) + return messages, closed, stream +end diff --git a/test/flight/grpcserver_extension/unary_tests.jl b/test/flight/grpcserver_extension/unary_tests.jl new file mode 100644 index 0000000..2f005a1 --- /dev/null +++ b/test/flight/grpcserver_extension/unary_tests.jl @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_test_unary(grpcserver, service, fixture, metadata) + grpc_descriptor = grpcserver.service_descriptor(service) + + unary_context = grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/GetFlightInfo"; + metadata=metadata, + ) + schema_context = grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/GetSchema"; + metadata=metadata, + ) + + direct_info = + grpc_descriptor.methods["GetFlightInfo"].handler(unary_context, fixture.descriptor) + @test direct_info.total_records == 3 + @test direct_info.endpoint[1].ticket.ticket == fixture.ticket.ticket + + direct_schema = + grpc_descriptor.methods["GetSchema"].handler(schema_context, fixture.descriptor) + @test Arrow.Flight.schemaipc(direct_schema) == fixture.schema_bytes +end diff --git a/test/flight/handshake_interop.jl b/test/flight/handshake_interop.jl new file mode 100644 index 0000000..0ef2711 --- /dev/null +++ b/test/flight/handshake_interop.jl @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@testset "Flight handshake interop" begin + server = FlightTestSupport.start_handshake_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + + try + FlightTestSupport.with_test_grpc_handle() do grpc + client = Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + + handshake_req, handshake_request, handshake_response = + Arrow.Flight.handshake(client) + put!(handshake_request, protocol.HandshakeRequest(UInt64(0), b"test")) + put!(handshake_request, protocol.HandshakeRequest(UInt64(0), b"p4ssw0rd")) + close(handshake_request) + + handshake_messages = collect(handshake_response) + gRPCClient.grpc_async_await(handshake_req) + + @test length(handshake_messages) == 1 + @test handshake_messages[1].protocol_version == 0 + @test handshake_messages[1].payload == b"secret:test" + + token_client = Arrow.Flight.withtoken(client, handshake_messages[1].payload) + actions_req, actions_channel = Arrow.Flight.listactions(token_client) + actions = collect(actions_channel) + gRPCClient.grpc_async_await(actions_req) + @test actions == + [protocol.ActionType("authenticated", "Requires a valid auth token")] + + auth_client, auth_messages = + Arrow.Flight.authenticate(client, "test", "p4ssw0rd") + @test length(auth_messages) == 1 + @test auth_messages[1].protocol_version == + handshake_messages[1].protocol_version + @test auth_messages[1].payload == handshake_messages[1].payload + @test auth_client.headers == ["auth-token-bin" => b"secret:test"] + + bad_req, bad_request, bad_response = Arrow.Flight.handshake(client) + put!(bad_request, protocol.HandshakeRequest(UInt64(0), b"test")) + put!(bad_request, protocol.HandshakeRequest(UInt64(0), b"wrong")) + close(bad_request) + + @test isempty(collect(bad_response)) + @test_throws gRPCClient.gRPCServiceCallException gRPCClient.grpc_async_await( + bad_req, + ) + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git a/test/flight/header_interop.jl b/test/flight/header_interop.jl new file mode 100644 index 0000000..a2f35a6 --- /dev/null +++ b/test/flight/header_interop.jl @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@testset "Flight header interop" begin + server = FlightTestSupport.start_headers_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + + try + FlightTestSupport.with_test_grpc_handle() do grpc + base_client = + Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + client = Arrow.Flight.withheaders( + base_client, + "authorization" => "Bearer token1234", + ) + + actions_req, actions_channel = Arrow.Flight.listactions(client) + actions = collect(actions_channel) + gRPCClient.grpc_async_await(actions_req) + @test actions == [ + protocol.ActionType( + "echo-authorization", + "Return the Authorization header", + ), + ] + + action_req, action_channel = Arrow.Flight.doaction( + client, + protocol.Action("echo-authorization", UInt8[]), + ) + action_results = collect(action_channel) + gRPCClient.grpc_async_await(action_req) + @test length(action_results) == 1 + @test String(action_results[1].body) == "Bearer token1234" + + call_req, call_channel = Arrow.Flight.doaction( + base_client, + protocol.Action("echo-authorization", UInt8[]); + headers=["authorization" => "Bearer call-level"], + ) + call_results = collect(call_channel) + gRPCClient.grpc_async_await(call_req) + @test length(call_results) == 1 + @test String(call_results[1].body) == "Bearer call-level" + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git a/test/flight/ipc_conversion.jl b/test/flight/ipc_conversion.jl new file mode 100644 index 0000000..285e274 --- /dev/null +++ b/test/flight/ipc_conversion.jl @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +using DataAPI +using Tables +using UUIDs + +@testset "Flight IPC conversion helpers" begin + missing_schema_fragment = "the server may have terminated the stream before emitting the first schema-bearing FlightData message" + descriptor = Arrow.Flight.Protocol.FlightDescriptor( + Arrow.Flight.Protocol.var"FlightDescriptor.DescriptorType".PATH, + UInt8[], + ["datasets", "roundtrip"], + ) + source = Tables.partitioner(( + (id=Int64[1, 2], label=["one", "two"]), + (id=Int64[3], label=["three"]), + )) + messages = Arrow.Flight.flightdata(source; descriptor=descriptor) + + @test !isempty(messages) + @test messages[1].flight_descriptor == descriptor + @test all(isnothing(msg.flight_descriptor) for msg in messages[2:end]) + @test !isempty(messages[1].data_header) + @test isempty(messages[1].data_body) + + bytes = Arrow.Flight.streambytes(messages) + @test Arrow.readbuffer(bytes, 1, UInt32) == Arrow.CONTINUATION_INDICATOR_BYTES + @test Arrow.readbuffer(bytes, length(bytes) - 3, Int32) == 0 + + batches = collect(Arrow.Flight.stream(messages)) + @test length(batches) == 2 + @test batches[1].id == [1, 2] + @test batches[2].label == ["three"] + + tbl = Arrow.Flight.table(messages) + @test tbl.id == [1, 2, 3] + @test tbl.label == ["one", "two", "three"] + + schema_bytes = Arrow.Flight.schemaipc(first(messages)) + @test Arrow.Flight.schemaipc(Arrow.Flight.Protocol.SchemaResult(schema_bytes[5:end])) == + schema_bytes + + stream_error = try + Arrow.Flight.stream(Arrow.Flight.Protocol.FlightData[]) + nothing + catch err + err + end + @test stream_error isa ArgumentError + @test occursin(missing_schema_fragment, sprint(showerror, stream_error)) + + table_error = try + Arrow.Flight.table(Arrow.Flight.Protocol.FlightData[]) + nothing + catch err + err + end + @test table_error isa ArgumentError + @test occursin(missing_schema_fragment, sprint(showerror, table_error)) + + empty_tbl = Arrow.Flight.table( + Arrow.Flight.Protocol.FlightData[]; + schema=Arrow.Flight.Protocol.SchemaResult(schema_bytes[5:end]), + ) + @test isempty(empty_tbl.id) + @test isempty(empty_tbl.label) + + metadata_source = Tables.partitioner(((title=["red", "blue"],), (title=["green"],))) + metadata_messages = Arrow.Flight.flightdata( + metadata_source; + metadata=Dict("dataset" => "flight"), + colmetadata=Dict(:title => Dict("lang" => "en")), + ) + metadata_schema_bytes = Arrow.Flight.schemaipc(first(metadata_messages)) + metadata_info = Arrow.Flight.Protocol.FlightInfo( + metadata_schema_bytes[5:end], + nothing, + Arrow.Flight.Protocol.FlightEndpoint[], + Int64(-1), + Int64(-1), + false, + UInt8[], + ) + metadata_batches = + collect(Arrow.Flight.stream(metadata_messages[2:end]; schema=metadata_info)) + metadata_table = Arrow.Flight.table(metadata_messages[2:end]; schema=metadata_info) + + @test length(metadata_batches) == 2 + @test DataAPI.metadata(metadata_batches[1], "dataset") == "flight" + @test DataAPI.colmetadata(metadata_batches[1], :title, "lang") == "en" + @test DataAPI.metadata(metadata_batches[2], "dataset") == "flight" + @test DataAPI.colmetadata(metadata_batches[2], :title, "lang") == "en" + @test metadata_table.title == ["red", "blue", "green"] + @test DataAPI.metadata(metadata_table, "dataset") == "flight" + @test DataAPI.colmetadata(metadata_table, :title, "lang") == "en" + metadata_parts = collect(Tables.partitions(metadata_table)) + @test length(metadata_parts) == 2 + @test metadata_parts[1].title == ["red", "blue"] + @test metadata_parts[2].title == ["green"] + @test DataAPI.metadata(metadata_parts[1], "dataset") == "flight" + @test DataAPI.colmetadata(metadata_parts[1], :title, "lang") == "en" + @test DataAPI.metadata(metadata_parts[2], "dataset") == "flight" + @test DataAPI.colmetadata(metadata_parts[2], :title, "lang") == "en" + + app_metadata_messages = Arrow.Flight.flightdata( + metadata_source; + metadata=Dict("dataset" => "flight"), + colmetadata=Dict(:title => Dict("lang" => "en")), + app_metadata=("batch:0", "batch:1"), + ) + metadata_batches_with_app = + collect(Arrow.Flight.stream(app_metadata_messages; include_app_metadata=true)) + metadata_table_with_app = + Arrow.Flight.table(app_metadata_messages; include_app_metadata=true) + @test length(metadata_batches_with_app) == 2 + @test metadata_batches_with_app[1].table.title == ["red", "blue"] + @test metadata_batches_with_app[2].table.title == ["green"] + @test String(metadata_batches_with_app[1].app_metadata) == "batch:0" + @test String(metadata_batches_with_app[2].app_metadata) == "batch:1" + @test metadata_table_with_app.table.title == ["red", "blue", "green"] + @test String.(metadata_table_with_app.app_metadata) == ["batch:0", "batch:1"] + + wrapped_metadata_source = Arrow.Flight.withappmetadata( + metadata_source; + app_metadata=("wrapped:0", "wrapped:1"), + ) + wrapped_metadata_messages = Arrow.Flight.flightdata( + wrapped_metadata_source; + metadata=Dict("dataset" => "flight"), + colmetadata=Dict(:title => Dict("lang" => "en")), + ) + wrapped_metadata_table = + Arrow.Flight.table(wrapped_metadata_messages; include_app_metadata=true) + @test wrapped_metadata_table.table.title == ["red", "blue", "green"] + @test String.(wrapped_metadata_table.app_metadata) == ["wrapped:0", "wrapped:1"] + + reemitted_channel = Channel{Arrow.Flight.Protocol.FlightData}(8) + reemit_task = @async Arrow.Flight.putflightdata!( + reemitted_channel, + Arrow.Flight.withappmetadata( + metadata_table_with_app.table; + app_metadata=("batch:0", "batch:1"), + ); + close=true, + ) + reemitted_messages = collect(reemitted_channel) + wait(reemit_task) + @test String.(getfield.(reemitted_messages[2:end], :app_metadata)) == + ["batch:0", "batch:1"] + reemitted_table = Arrow.Flight.table(reemitted_messages) + @test reemitted_table.title == metadata_table.title + @test DataAPI.metadata(reemitted_table, "dataset") == "flight" + @test DataAPI.colmetadata(reemitted_table, :title, "lang") == "en" + + app_metadata_error = try + Arrow.Flight.flightdata(metadata_source; app_metadata=("only-one",)) + nothing + catch err + err + end + @test app_metadata_error isa ArgumentError + @test occursin("app_metadata was exhausted", sprint(showerror, app_metadata_error)) + + duplicate_app_metadata_error = try + Arrow.Flight.flightdata( + Arrow.Flight.withappmetadata( + metadata_source; + app_metadata=("wrapped:0", "wrapped:1"), + ); + app_metadata=("extra:0", "extra:1"), + ) + nothing + catch err + err + end + @test duplicate_app_metadata_error isa ArgumentError + @test occursin( + "Arrow.Flight.withappmetadata", + sprint(showerror, duplicate_app_metadata_error), + ) + + extension_source = ( + uuid=[UUID(UInt128(1)), UUID(UInt128(2))], + flag=[Arrow.Bool8(true), Arrow.Bool8(false)], + json=Union{Missing,Arrow.JSONText{String}}[Arrow.JSONText("{\"a\":1}"), missing], + ts=Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}}[ + Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), + Int16(-480), + ), + missing, + ], + ) + extension_messages = Arrow.Flight.flightdata(extension_source) + extension_batches = collect(Arrow.Flight.stream(extension_messages)) + extension_tbl = Arrow.Flight.table(extension_messages) + + @test Arrow.getmetadata(extension_batches[1].uuid)[Arrow.EXTENSION_NAME_KEY] == + "arrow.uuid" + @test Arrow.getmetadata(extension_batches[1].flag)[Arrow.EXTENSION_NAME_KEY] == + "arrow.bool8" + @test Arrow.getmetadata(extension_batches[1].json)[Arrow.EXTENSION_NAME_KEY] == + "arrow.json" + @test Arrow.getmetadata(extension_batches[1].ts)[Arrow.EXTENSION_NAME_KEY] == + "arrow.timestamp_with_offset" + @test Arrow.getmetadata(extension_tbl.uuid)[Arrow.EXTENSION_NAME_KEY] == "arrow.uuid" + @test Arrow.getmetadata(extension_tbl.flag)[Arrow.EXTENSION_NAME_KEY] == "arrow.bool8" + @test Arrow.getmetadata(extension_tbl.json)[Arrow.EXTENSION_NAME_KEY] == "arrow.json" + @test Arrow.getmetadata(extension_tbl.ts)[Arrow.EXTENSION_NAME_KEY] == + "arrow.timestamp_with_offset" + @test copy(extension_tbl.uuid) == extension_source.uuid + @test Bool.(copy(extension_tbl.flag)) == Bool.(extension_source.flag) + @test isequal(copy(extension_tbl.json), extension_source.json) + @test isequal(copy(extension_tbl.ts), extension_source.ts) +end diff --git a/test/flight/ipc_schema_separation.jl b/test/flight/ipc_schema_separation.jl new file mode 100644 index 0000000..facd1e1 --- /dev/null +++ b/test/flight/ipc_schema_separation.jl @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@testset "Flight IPC schema separation" begin + source = Tables.partitioner(((word=["red", "blue"],), (word=["red", "green"],))) + messages = Arrow.Flight.flightdata(source; dictencode=true) + schema_bytes = Arrow.Flight.schemaipc(first(messages)) + info = Arrow.Flight.Protocol.FlightInfo( + schema_bytes[5:end], + nothing, + Arrow.Flight.Protocol.FlightEndpoint[], + Int64(-1), + Int64(-1), + false, + UInt8[], + ) + payload = messages[2:end] + + @test length(messages) >= 4 + @test Arrow.Flight.schemaipc(info) == schema_bytes + + batches = collect(Arrow.Flight.stream(payload; schema=info)) + @test length(batches) == 2 + @test isequal(batches[1].word, ["red", "blue"]) + @test isequal(batches[2].word, ["red", "green"]) + + tbl = Arrow.Flight.table(payload; schema=info) + @test isequal(tbl.word, ["red", "blue", "red", "green"]) +end diff --git a/test/flight/poll_interop.jl b/test/flight/poll_interop.jl new file mode 100644 index 0000000..b3a5cf1 --- /dev/null +++ b/test/flight/poll_interop.jl @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@testset "Flight poll interop" begin + server = FlightTestSupport.start_poll_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + + try + FlightTestSupport.with_test_grpc_handle() do grpc + client = Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + initial_descriptor = protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "poll"], + ) + + first_poll = Arrow.Flight.pollflightinfo(client, initial_descriptor) + @test !isnothing(first_poll.info) + @test !isnothing(first_poll.flight_descriptor) + @test first_poll.flight_descriptor.path == ["interop", "poll", "retry"] + @test first_poll.info.total_records == 1 + @test first_poll.info.ordered + @test first_poll.progress ≈ 0.5 + @test Arrow.Flight.schemaipc(first_poll.info) == Arrow.Flight.schemaipc( + protocol.SchemaResult(first_poll.info.schema[5:end]), + ) + + second_poll = + Arrow.Flight.pollflightinfo(client, first_poll.flight_descriptor) + @test !isnothing(second_poll.info) + @test isnothing(second_poll.flight_descriptor) + @test second_poll.info.flight_descriptor.path == ["interop", "poll"] + @test second_poll.progress ≈ 1.0 + @test length(second_poll.info.endpoint) == 1 + @test second_poll.info.endpoint[1].ticket.ticket == b"poll-ticket" + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git a/test/flight/pyarrow_interop.jl b/test/flight/pyarrow_interop.jl new file mode 100644 index 0000000..ebf0e0f --- /dev/null +++ b/test/flight/pyarrow_interop.jl @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("pyarrow_interop/support.jl") +include("pyarrow_interop/discovery_tests.jl") +include("pyarrow_interop/download_tests.jl") +include("pyarrow_interop/upload_tests.jl") +include("pyarrow_interop/exchange_tests.jl") + +@testset "Flight pyarrow interop" begin + server = FlightTestSupport.start_pyarrow_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + descriptors = pyarrow_interop_descriptors(protocol) + + try + FlightTestSupport.with_test_grpc_handle() do grpc + client = Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + pyarrow_interop_test_discovery(client, protocol, descriptors.download) + pyarrow_interop_test_download(client, descriptors.download) + pyarrow_interop_test_upload(client, descriptors.upload) + pyarrow_interop_test_exchange(client, descriptors.exchange) + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git a/test/flight/pyarrow_interop/discovery_tests.jl b/test/flight/pyarrow_interop/discovery_tests.jl new file mode 100644 index 0000000..ee70af1 --- /dev/null +++ b/test/flight/pyarrow_interop/discovery_tests.jl @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_discovery(client, protocol, download_descriptor) + flights_req, flights_channel = Arrow.Flight.listflights(client) + flights = pyarrow_interop_collect(flights_req, flights_channel) + @test any( + info -> + !isnothing(info.flight_descriptor) && + info.flight_descriptor.path == download_descriptor.path, + flights, + ) + + actions_req, actions_channel = Arrow.Flight.listactions(client) + actions = pyarrow_interop_collect(actions_req, actions_channel) + @test any(action -> action.var"#type" == "ping", actions) + + action_req, action_channel = + Arrow.Flight.doaction(client, protocol.Action("ping", UInt8[])) + action_results = pyarrow_interop_collect(action_req, action_channel) + @test length(action_results) == 1 + @test String(action_results[1].body) == "pong" +end diff --git a/test/flight/pyarrow_interop/download_tests.jl b/test/flight/pyarrow_interop/download_tests.jl new file mode 100644 index 0000000..6b6b25a --- /dev/null +++ b/test/flight/pyarrow_interop/download_tests.jl @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_download(client, download_descriptor) + download_info = Arrow.Flight.getflightinfo(client, download_descriptor) + @test download_info.total_records == 3 + @test length(download_info.endpoint) == 1 + + download_schema = Arrow.Flight.getschema(client, download_descriptor) + @test Arrow.Flight.schemaipc(download_schema) == Arrow.Flight.schemaipc(download_info) + + doget_req, doget_channel = Arrow.Flight.doget(client, download_info.endpoint[1].ticket) + download_messages = pyarrow_interop_collect(doget_req, doget_channel) + + download_table = Arrow.Flight.table(download_messages; schema=download_info) + @test download_table.id == [1, 2, 3] + @test download_table.name == ["one", "two", "three"] +end diff --git a/test/flight/pyarrow_interop/exchange_tests.jl b/test/flight/pyarrow_interop/exchange_tests.jl new file mode 100644 index 0000000..7d212e1 --- /dev/null +++ b/test/flight/pyarrow_interop/exchange_tests.jl @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_exchange(client, exchange_descriptor) + exchange_source = Tables.partitioner(( + (id=Int64[21, 22], name=["twenty-one", "twenty-two"]), + (id=Int64[23], name=["twenty-three"]), + )) + exchange_metadata = Dict("dataset" => "interop-exchange") + exchange_colmetadata = Dict(:name => Dict("lang" => "en")) + exchange_app_metadata = ["client:0", "client:1"] + exchange_source = + Arrow.Flight.withappmetadata(exchange_source; app_metadata=exchange_app_metadata) + exchange_req, exchange_response = Arrow.Flight.doexchange( + client, + exchange_source; + descriptor=exchange_descriptor, + metadata=exchange_metadata, + colmetadata=exchange_colmetadata, + ) + exchanged_messages = Arrow.Flight.Protocol.FlightData[] + exchange_batches = collect( + Arrow.Flight.stream(( + (push!(exchanged_messages, message); message) for message in exchange_response + ),), + ) + gRPCClient.grpc_async_await(exchange_req) + + @test length(exchange_batches) == 2 + @test exchange_batches[1].id == [21, 22] + @test exchange_batches[1].name == ["twenty-one", "twenty-two"] + @test DataAPI.metadata(exchange_batches[1], "dataset") == "interop-exchange" + @test DataAPI.colmetadata(exchange_batches[1], :name, "lang") == "en" + @test exchange_batches[2].id == [23] + @test exchange_batches[2].name == ["twenty-three"] + @test DataAPI.metadata(exchange_batches[2], "dataset") == "interop-exchange" + @test DataAPI.colmetadata(exchange_batches[2], :name, "lang") == "en" + exchange_table = Arrow.Flight.table(exchanged_messages) + @test exchange_table.id == [21, 22, 23] + @test exchange_table.name == ["twenty-one", "twenty-two", "twenty-three"] + @test DataAPI.metadata(exchange_table, "dataset") == "interop-exchange" + @test DataAPI.colmetadata(exchange_table, :name, "lang") == "en" + @test filter(!isempty, getfield.(exchanged_messages, :app_metadata)) == + [b"client:0", b"client:1"] + + exchange_batches_with_app = + collect(Arrow.Flight.stream(exchanged_messages; include_app_metadata=true)) + @test exchange_batches_with_app[1].table.id == [21, 22] + @test exchange_batches_with_app[2].table.id == [23] + @test String.(getproperty.(exchange_batches_with_app, :app_metadata)) == + exchange_app_metadata + + exchange_table_with_app = + Arrow.Flight.table(exchanged_messages; include_app_metadata=true) + @test exchange_table_with_app.table.id == [21, 22, 23] + @test exchange_table_with_app.table.name == ["twenty-one", "twenty-two", "twenty-three"] + @test String.(exchange_table_with_app.app_metadata) == exchange_app_metadata +end diff --git a/test/flight/pyarrow_interop/support.jl b/test/flight/pyarrow_interop/support.jl new file mode 100644 index 0000000..893f65c --- /dev/null +++ b/test/flight/pyarrow_interop/support.jl @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_descriptors(protocol) + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + return ( + download=protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "download"], + ), + upload=protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "upload"], + ), + exchange=protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "exchange"], + ), + ) +end + +function pyarrow_interop_collect(req, channel) + messages = collect(channel) + gRPCClient.grpc_async_await(req) + return messages +end + +function pyarrow_interop_send_messages(req, request, response, messages) + for message in messages + put!(request, message) + end + close(request) + responses = collect(response) + gRPCClient.grpc_async_await(req) + return responses +end diff --git a/test/flight/pyarrow_interop/upload_tests.jl b/test/flight/pyarrow_interop/upload_tests.jl new file mode 100644 index 0000000..e07b8de --- /dev/null +++ b/test/flight/pyarrow_interop/upload_tests.jl @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_upload(client, upload_descriptor) + upload_source = Tables.partitioner(( + (id=Int64[10, 11], name=["ten", "eleven"]), + (id=Int64[12], name=["twelve"]), + )) + upload_metadata = Dict("dataset" => "interop-upload") + upload_colmetadata = Dict(:name => Dict("lang" => "en")) + doput_req, doput_response = Arrow.Flight.doput( + client, + upload_source; + descriptor=upload_descriptor, + metadata=upload_metadata, + colmetadata=upload_colmetadata, + ) + put_results = collect(doput_response) + gRPCClient.grpc_async_await(doput_req) + + @test !isempty(put_results) + @test String(put_results[end].app_metadata) == "stored" + + uploaded_info = Arrow.Flight.getflightinfo(client, upload_descriptor) + uploaded_req, uploaded_channel = + Arrow.Flight.doget(client, uploaded_info.endpoint[1].ticket) + uploaded_messages = pyarrow_interop_collect(uploaded_req, uploaded_channel) + + uploaded_table = Arrow.Flight.table(uploaded_messages; schema=uploaded_info) + @test uploaded_table.id == [10, 11, 12] + @test uploaded_table.name == ["ten", "eleven", "twelve"] + @test DataAPI.metadata(uploaded_table, "dataset") == "interop-upload" + @test DataAPI.colmetadata(uploaded_table, :name, "lang") == "en" +end diff --git a/test/flight/server_core.jl b/test/flight/server_core.jl new file mode 100644 index 0000000..3ca8e2d --- /dev/null +++ b/test/flight/server_core.jl @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("server_core/support.jl") +include("server_core/metadata_tests.jl") +include("server_core/descriptor_tests.jl") +include("server_core/direct_handler_tests.jl") +include("server_core/dispatch_tests.jl") + +@testset "Flight server core surface" begin + fixture = flight_server_core_fixture() + flight_server_core_test_metadata(fixture) + flight_server_core_test_descriptors(fixture) + flight_server_core_test_direct_handlers(fixture) + flight_server_core_test_dispatch(fixture) +end diff --git a/test/flight/server_core/descriptor_tests.jl b/test/flight/server_core/descriptor_tests.jl new file mode 100644 index 0000000..1cf2455 --- /dev/null +++ b/test/flight/server_core/descriptor_tests.jl @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_test_descriptors(fixture) + handshake_descriptor = Arrow.Flight.lookupmethod(fixture.descriptor_info, "Handshake") + @test !isnothing(handshake_descriptor) + @test handshake_descriptor.path == "/arrow.flight.protocol.FlightService/Handshake" + @test handshake_descriptor.request_streaming + @test handshake_descriptor.response_streaming + @test handshake_descriptor.request_type === fixture.protocol.HandshakeRequest + @test handshake_descriptor.response_type === fixture.protocol.HandshakeResponse + + doget_descriptor = Arrow.Flight.lookupmethod( + fixture.descriptor_info, + "/arrow.flight.protocol.FlightService/DoGet", + ) + @test !isnothing(doget_descriptor) + @test !doget_descriptor.request_streaming + @test doget_descriptor.response_streaming + @test doget_descriptor.request_type === fixture.protocol.Ticket + @test doget_descriptor.response_type === fixture.protocol.FlightData + @test isnothing(Arrow.Flight.lookupmethod(fixture.descriptor_info, "MissingMethod")) +end diff --git a/test/flight/server_core/direct_handler_tests.jl b/test/flight/server_core/direct_handler_tests.jl new file mode 100644 index 0000000..e936dad --- /dev/null +++ b/test/flight/server_core/direct_handler_tests.jl @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_test_direct_handlers(fixture) + @test_throws gRPCClient.gRPCServiceCallException Arrow.Flight.getflightinfo( + fixture.service, + fixture.context, + fixture.descriptor, + ) + + info = + Arrow.Flight.getflightinfo(fixture.implemented, fixture.context, fixture.descriptor) + @test info.total_records == 7 + @test info.total_bytes == 42 + @test info.flight_descriptor.path == ["server", "dataset"] + + get_response = Channel{fixture.protocol.FlightData}(1) + @test Arrow.Flight.doget( + fixture.implemented, + fixture.context, + fixture.protocol.Ticket(b"ticket-1"), + get_response, + ) == :doget_ok + @test length(collect(get_response)) == 1 + + actions_response = Channel{fixture.protocol.ActionType}(1) + @test Arrow.Flight.listactions( + fixture.implemented, + fixture.context, + actions_response, + ) == :listactions_ok + actions = collect(actions_response) + @test length(actions) == 1 + @test getfield(actions[1], Symbol("#type")) == "ping" + @test actions[1].description == "Ping action" +end diff --git a/test/flight/server_core/dispatch_tests.jl b/test/flight/server_core/dispatch_tests.jl new file mode 100644 index 0000000..28bc7b4 --- /dev/null +++ b/test/flight/server_core/dispatch_tests.jl @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_test_dispatch(fixture) + dispatch_info = Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + "/arrow.flight.protocol.FlightService/GetFlightInfo", + fixture.descriptor, + ) + @test dispatch_info.total_records == 7 + @test dispatch_info.flight_descriptor.path == ["server", "dataset"] + + doget_descriptor = Arrow.Flight.lookupmethod( + fixture.descriptor_info, + "/arrow.flight.protocol.FlightService/DoGet", + ) + get_response = Channel{fixture.protocol.FlightData}(1) + @test Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + doget_descriptor, + fixture.protocol.Ticket(b"ticket-1"), + get_response, + ) == :doget_ok + @test length(collect(get_response)) == 1 + + actions_response = Channel{fixture.protocol.ActionType}(1) + @test Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + "ListActions", + actions_response, + ) == :listactions_ok + @test length(collect(actions_response)) == 1 + @test_throws ArgumentError Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + "/arrow.flight.protocol.FlightService/MissingMethod", + fixture.descriptor, + ) +end diff --git a/test/flight/server_core/metadata_tests.jl b/test/flight/server_core/metadata_tests.jl new file mode 100644 index 0000000..3393792 --- /dev/null +++ b/test/flight/server_core/metadata_tests.jl @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_test_metadata(fixture) + @test Arrow.Flight.callheader(fixture.context, "authorization") == "Bearer test" + @test Arrow.Flight.callheader(fixture.context, "Authorization") == "Bearer test" + @test Arrow.Flight.callheader(fixture.context, "auth-token-bin") == UInt8[0x01, 0x02] + @test isnothing(Arrow.Flight.callheader(fixture.context, "missing")) + @test fixture.descriptor_info.name == "arrow.flight.protocol.FlightService" + @test length(fixture.descriptor_info.methods) == 10 +end diff --git a/test/flight/server_core/support.jl b/test/flight/server_core/support.jl new file mode 100644 index 0000000..3c3397a --- /dev/null +++ b/test/flight/server_core/support.jl @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_fixture() + protocol = Arrow.Flight.Protocol + context = Arrow.Flight.ServerCallContext( + headers=["authorization" => "Bearer test", "auth-token-bin" => UInt8[0x01, 0x02]], + peer="127.0.0.1:4000", + secure=true, + ) + descriptor_info = Arrow.Flight.servicedescriptor(Arrow.Flight.Service()) + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + descriptor = + protocol.FlightDescriptor(descriptor_type.PATH, UInt8[], ["server", "dataset"]) + service = Arrow.Flight.Service() + implemented = Arrow.Flight.Service( + getflightinfo=(ctx, req) -> begin + @test ctx === context + @test req.path == descriptor.path + return protocol.FlightInfo( + UInt8[], + req, + protocol.FlightEndpoint[], + 7, + 42, + false, + UInt8[], + ) + end, + doget=(ctx, ticket, response) -> begin + @test ctx === context + @test ticket.ticket == b"ticket-1" + put!(response, protocol.FlightData(nothing, UInt8[], UInt8[], UInt8[])) + close(response) + return :doget_ok + end, + listactions=(ctx, response) -> begin + @test ctx === context + put!(response, protocol.ActionType("ping", "Ping action")) + close(response) + return :listactions_ok + end, + ) + return (; protocol, context, descriptor_info, descriptor, service, implemented) +end diff --git a/test/flight/support.jl b/test/flight/support.jl new file mode 100644 index 0000000..d1cb8d3 --- /dev/null +++ b/test/flight/support.jl @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module FlightTestSupport + +using gRPCClient + +export PyArrowFlightServer, + flight_test_roots, + pyarrow_flight_python, + start_pyarrow_flight_server, + start_headers_flight_server, + start_handshake_flight_server, + start_poll_flight_server, + start_tls_flight_server, + stop_pyarrow_flight_server, + with_test_grpc_handle, + load_grpcserver, + generate_test_tls_certificate, + next_message_factory + +include("support/types.jl") +include("support/paths.jl") +include("support/python_servers.jl") +include("support/grpc.jl") +include("support/tls.jl") +include("support/streams.jl") + +end diff --git a/test/flight/support/grpc.jl b/test/flight/support/grpc.jl new file mode 100644 index 0000000..a9eb7cd --- /dev/null +++ b/test/flight/support/grpc.jl @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function with_test_grpc_handle(f::F) where {F} + grpc = gRPCClient.gRPCCURL() + gRPCClient.grpc_init(grpc) + try + return f(grpc) + finally + gRPCClient.grpc_shutdown(grpc) + end +end + +function load_grpcserver() + isnothing(Base.find_package("gRPCServer")) && return nothing + return Base.require( + Base.PkgId(Base.UUID("608c6337-0d7d-447f-bb69-0f5674ee3959"), "gRPCServer"), + ) +end diff --git a/test/flight/support/paths.jl b/test/flight/support/paths.jl new file mode 100644 index 0000000..0632072 --- /dev/null +++ b/test/flight/support/paths.jl @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const TEST_ROOT = normpath(joinpath(@__DIR__, "..", "..")) + +function git_toplevel(path::AbstractString) + try + cmd = pipeline( + Cmd(["git", "-C", path, "rev-parse", "--show-toplevel"]); + stderr=devnull, + ) + return normpath(chomp(read(cmd, String))) + catch + return nothing + end +end + +function flight_test_roots() + roots = String[] + path = abspath(TEST_ROOT) + while true + top = git_toplevel(path) + !isnothing(top) && push!(roots, top) + parent = dirname(path) + parent == path && break + path = parent + end + push!(roots, TEST_ROOT) + unique!(roots) + return roots +end + +function pyarrow_flight_python() + haskey(ENV, "ARROW_FLIGHT_PYTHON") && return ENV["ARROW_FLIGHT_PYTHON"] + cache_home = get(ENV, "PRJ_CACHE_HOME", ".cache") + for root in flight_test_roots() + python = joinpath(root, cache_home, "arrow-julia-flight-pyenv", "bin", "python") + isfile(python) && return python + end + return nothing +end diff --git a/test/flight/support/python_servers.jl b/test/flight/support/python_servers.jl new file mode 100644 index 0000000..07ba693 --- /dev/null +++ b/test/flight/support/python_servers.jl @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function start_python_flight_server( + script_name::AbstractString; + env_overrides::AbstractDict{<:AbstractString,<:AbstractString}=Dict{String,String}(), +) + python = pyarrow_flight_python() + isnothing(python) && return nothing + + stdout = Pipe() + stderr = Pipe() + env = merge( + Dict{String,String}(ENV), + Dict("PYTHONUNBUFFERED" => "1"), + Dict{String,String}(string(k) => string(v) for (k, v) in pairs(env_overrides)), + ) + cmd = setenv(Cmd([python, joinpath(TEST_ROOT, script_name)]), env) + process = run(pipeline(cmd; stdout=stdout, stderr=stderr), wait=false) + close(stdout.in) + close(stderr.in) + + line = try + readline(stdout) + catch err + errout = read(stderr, String) + wait(process) + error( + "failed to start pyarrow Flight server: $(sprint(showerror, err)); stderr=$(repr(errout))", + ) + end + port = parse(Int, chomp(line)) + return PyArrowFlightServer(process, stdout, stderr, port) +end + +start_pyarrow_flight_server() = start_python_flight_server("flight_pyarrow_server.py") +start_headers_flight_server() = start_python_flight_server("flight_headers_server.py") +start_handshake_flight_server() = start_python_flight_server("flight_handshake_server.py") +start_poll_flight_server() = start_python_flight_server("flight_poll_server.py") +function start_tls_flight_server(cert_path::AbstractString, key_path::AbstractString) + start_python_flight_server( + "flight_tls_server.py"; + env_overrides=Dict( + "ARROW_FLIGHT_TLS_CERT" => String(cert_path), + "ARROW_FLIGHT_TLS_KEY" => String(key_path), + ), + ) +end + +function stop_pyarrow_flight_server(server::PyArrowFlightServer) + try + kill(server.process) + catch + end + try + wait(server.process) + catch + end + close(server.stdout) + close(server.stderr) + return +end + +stop_pyarrow_flight_server(::Nothing) = nothing diff --git a/test/flight/support/streams.jl b/test/flight/support/streams.jl new file mode 100644 index 0000000..8b9d048 --- /dev/null +++ b/test/flight/support/streams.jl @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function next_message_factory(messages) + index = Ref(1) + return () -> begin + current = index[] + current > length(messages) && return nothing + index[] = current + 1 + return messages[current] + end +end diff --git a/test/flight/support/tls.jl b/test/flight/support/tls.jl new file mode 100644 index 0000000..ca7dfcf --- /dev/null +++ b/test/flight/support/tls.jl @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function generate_test_tls_certificate(dir::AbstractString) + openssl = Sys.which("openssl") + isnothing(openssl) && return nothing + + config_path = joinpath(dir, "openssl.cnf") + cert_path = joinpath(dir, "cert.pem") + key_path = joinpath(dir, "key.pem") + write( + config_path, + """ + [req] + distinguished_name = dn + x509_extensions = v3_req + prompt = no + + [dn] + CN = localhost + + [v3_req] + subjectAltName = @alt_names + + [alt_names] + DNS.1 = localhost + IP.1 = 127.0.0.1 + """, + ) + run( + Cmd([ + openssl, + "req", + "-x509", + "-nodes", + "-newkey", + "rsa:2048", + "-keyout", + key_path, + "-out", + cert_path, + "-days", + "1", + "-config", + config_path, + "-extensions", + "v3_req", + ]), + ) + return cert_path, key_path +end diff --git a/test/flight/support/types.jl b/test/flight/support/types.jl new file mode 100644 index 0000000..5cbca29 --- /dev/null +++ b/test/flight/support/types.jl @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +struct PyArrowFlightServer + process::Base.Process + stdout::Pipe + stderr::Pipe + port::Int +end diff --git a/test/flight/tls_interop.jl b/test/flight/tls_interop.jl new file mode 100644 index 0000000..7cdc3a7 --- /dev/null +++ b/test/flight/tls_interop.jl @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@testset "Flight TLS interop" begin + mktempdir() do dir + tls_material = FlightTestSupport.generate_test_tls_certificate(dir) + if isnothing(tls_material) + @test true + return + end + cert_path, key_path = tls_material + server = FlightTestSupport.start_tls_flight_server(cert_path, key_path) + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + descriptor = protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "tls", "download"], + ) + + try + FlightTestSupport.with_test_grpc_handle() do grpc + client = Arrow.Flight.Client( + "grpc+tls://localhost:$(server.port)"; + grpc=grpc, + tls_root_certs=cert_path, + ) + info = Arrow.Flight.getflightinfo(client, descriptor) + @test info.total_records == 3 + @test length(info.endpoint) == 1 + + schema = Arrow.Flight.getschema(client, descriptor) + @test Arrow.Flight.schemaipc(schema) == Arrow.Flight.schemaipc(info) + + req, channel = Arrow.Flight.doget(client, info.endpoint[1].ticket) + messages = collect(channel) + gRPCClient.grpc_async_await(req) + + table = Arrow.Flight.table(messages; schema=info) + @test table.id == [31, 32, 33] + @test table.name == ["thirty-one", "thirty-two", "thirty-three"] + end + + FlightTestSupport.with_test_grpc_handle() do grpc + insecure_client = Arrow.Flight.Client( + "grpc+tls://localhost:$(server.port)"; + grpc=grpc, + disable_server_verification=true, + ) + info = Arrow.Flight.getflightinfo(insecure_client, descriptor) + @test info.total_records == 3 + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end + end +end diff --git a/test/flight_grpcserver.jl b/test/flight_grpcserver.jl new file mode 100644 index 0000000..7fd8c4c --- /dev/null +++ b/test/flight_grpcserver.jl @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +using Pkg + +const TEST_ROOT = @__DIR__ +const ARROW_ROOT = normpath(joinpath(TEST_ROOT, "..")) +const ARROWTYPES_ROOT = joinpath(ARROW_ROOT, "src", "ArrowTypes") + +function maybe_git_root(path::AbstractString) + try + return readchomp(pipeline(`git -C $path rev-parse --show-toplevel`; stderr=devnull)) + catch + return nothing + end +end + +function flight_grpcserver_roots(path::AbstractString) + roots = String[] + current = abspath(path) + while true + root = maybe_git_root(current) + if !isnothing(root) && root ∉ roots + push!(roots, root) + end + parent = dirname(current) + parent == current && break + current = parent + end + return roots +end + +function locate_grpcserver() + if haskey(ENV, "ARROW_FLIGHT_GRPCSERVER_PATH") + candidate = abspath(ENV["ARROW_FLIGHT_GRPCSERVER_PATH"]) + isdir(candidate) || error("ARROW_FLIGHT_GRPCSERVER_PATH does not exist: $candidate") + return candidate + end + for root in flight_grpcserver_roots(TEST_ROOT) + candidate = joinpath(root, ".cache", "vendor", "gRPCServer.jl") + isdir(candidate) && return candidate + end + error( + "Could not locate vendored gRPCServer.jl. " * + "Set ARROW_FLIGHT_GRPCSERVER_PATH to an explicit checkout path.", + ) +end + +const TEMP_ENV = mktempdir() +cp(joinpath(TEST_ROOT, "Project.toml"), joinpath(TEMP_ENV, "Project.toml")) + +Pkg.activate(TEMP_ENV) +Pkg.develop(PackageSpec(path=ARROW_ROOT)) +Pkg.develop(PackageSpec(path=ARROWTYPES_ROOT)) +Pkg.develop(PackageSpec(path=locate_grpcserver())) +Pkg.instantiate() + +using Test +using Arrow + +include(joinpath(TEST_ROOT, "flight.jl")) diff --git a/test/flight_handshake_server.py b/test/flight_handshake_server.py new file mode 100644 index 0000000..ab4f855 --- /dev/null +++ b/test/flight_handshake_server.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#!/usr/bin/env python3 + +import signal + +import pyarrow.flight as fl + + +class TokenAuthHandler(fl.ServerAuthHandler): + def authenticate(self, outgoing, incoming): + username = incoming.read() + password = incoming.read() + if username == b"test" and password == b"p4ssw0rd": + outgoing.write(b"secret:test") + return + raise fl.FlightUnauthenticatedError("invalid username/password") + + def is_valid(self, token): + if token != b"secret:test": + raise fl.FlightUnauthenticatedError("invalid token") + return b"test" + + +class HandshakeFlightServer(fl.FlightServerBase): + def __init__(self): + super().__init__( + location="grpc://127.0.0.1:0", + auth_handler=TokenAuthHandler(), + ) + + def list_actions(self, context): + del context + return [fl.ActionType("authenticated", "Requires a valid auth token")] + + +def main(): + server = HandshakeFlightServer() + + def shutdown_handler(signum, frame): + del signum, frame + server.shutdown() + + signal.signal(signal.SIGTERM, shutdown_handler) + signal.signal(signal.SIGINT, shutdown_handler) + + print(server.port, flush=True) + server.serve() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/flight_headers_server.py b/test/flight_headers_server.py new file mode 100644 index 0000000..39b86cc --- /dev/null +++ b/test/flight_headers_server.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#!/usr/bin/env python3 + +import signal + +import pyarrow.flight as fl + + +def case_insensitive_header_lookup(headers, lookup_key): + lookup_key = lookup_key.lower() + for key, value in headers.items(): + if key.lower() == lookup_key: + return value + raise fl.FlightUnauthenticatedError(f"missing required header: {lookup_key}") + + +class HeaderEchoServerMiddlewareFactory(fl.ServerMiddlewareFactory): + def start_call(self, info, headers): + del info + authorization = case_insensitive_header_lookup(headers, "authorization") + return HeaderEchoServerMiddleware(authorization[0]) + + +class HeaderEchoServerMiddleware(fl.ServerMiddleware): + def __init__(self, authorization): + self.authorization = authorization + + +class HeaderEchoFlightServer(fl.FlightServerBase): + def __init__(self): + super().__init__( + location="grpc://127.0.0.1:0", + middleware={"auth": HeaderEchoServerMiddlewareFactory()}, + ) + + def list_actions(self, context): + del context + return [("echo-authorization", "Return the Authorization header")] + + def do_action(self, context, action): + if action.type != "echo-authorization": + raise KeyError(f"unsupported action: {action.type}") + middleware = context.get_middleware("auth") + if middleware is None: + raise fl.FlightUnauthenticatedError("missing auth middleware") + return [middleware.authorization.encode("utf-8")] + + +def main(): + server = HeaderEchoFlightServer() + + def shutdown_handler(signum, frame): + del signum, frame + server.shutdown() + + signal.signal(signal.SIGTERM, shutdown_handler) + signal.signal(signal.SIGINT, shutdown_handler) + + print(server.port, flush=True) + server.serve() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/flight_poll_server.py b/test/flight_poll_server.py new file mode 100644 index 0000000..17d0392 --- /dev/null +++ b/test/flight_poll_server.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#!/usr/bin/env python3 + +import pathlib +import signal +import sys +import tempfile +from concurrent import futures + +import grpc +import pyarrow as pa +import grpc_tools +from grpc_tools import protoc + + +ROOT = pathlib.Path(__file__).resolve().parent.parent / "src" / "flight" / "proto" +PROTO = ROOT / "Flight.proto" +GRPC_TOOLS_PROTO = pathlib.Path(grpc_tools.__file__).resolve().parent / "_proto" + + +def load_proto_modules(): + out = pathlib.Path(tempfile.mkdtemp(prefix="flight_poll_proto_")) + result = protoc.main( + [ + "grpc_tools.protoc", + f"-I{ROOT}", + f"-I{GRPC_TOOLS_PROTO}", + f"--python_out={out}", + f"--grpc_python_out={out}", + str(PROTO), + ] + ) + if result != 0: + raise RuntimeError(f"protoc failed with exit code {result}") + sys.path.insert(0, str(out)) + import Flight_pb2 + import Flight_pb2_grpc + + return Flight_pb2, Flight_pb2_grpc + + +def descriptor_key(descriptor): + return tuple(descriptor.path) + + +def main(): + pb2, pb2_grpc = load_proto_modules() + + class PollFlightInfoServicer(pb2_grpc.FlightServiceServicer): + def __init__(self, port): + self.pb2 = pb2 + self.port = port + self.schema_bytes = bytes(pa.schema([("id", pa.int64())]).serialize()) + + def _descriptor(self, path): + return self.pb2.FlightDescriptor( + type=self.pb2.FlightDescriptor.PATH, + path=list(path), + ) + + def _flight_info(self, path): + endpoint = self.pb2.FlightEndpoint( + ticket=self.pb2.Ticket(ticket=b"poll-ticket"), + location=[self.pb2.Location(uri=f"grpc://127.0.0.1:{self.port}")], + ) + return self.pb2.FlightInfo( + schema=self.schema_bytes, + flight_descriptor=self._descriptor(path), + endpoint=[endpoint], + total_records=1, + total_bytes=8, + ordered=True, + ) + + def PollFlightInfo(self, request, context): + del context + key = descriptor_key(request) + if key == ("interop", "poll"): + return self.pb2.PollInfo( + info=self._flight_info(key), + flight_descriptor=self._descriptor(("interop", "poll", "retry")), + progress=0.5, + ) + if key == ("interop", "poll", "retry"): + return self.pb2.PollInfo( + info=self._flight_info(("interop", "poll")), + progress=1.0, + ) + raise KeyError(f"unsupported poll descriptor: {key}") + + server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) + port = server.add_insecure_port("127.0.0.1:0") + pb2_grpc.add_FlightServiceServicer_to_server(PollFlightInfoServicer(port), server) + + def shutdown_handler(signum, frame): + del signum, frame + server.stop(grace=None) + + signal.signal(signal.SIGTERM, shutdown_handler) + signal.signal(signal.SIGINT, shutdown_handler) + + server.start() + print(port, flush=True) + server.wait_for_termination() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/flight_pyarrow_server.py b/test/flight_pyarrow_server.py new file mode 100644 index 0000000..b100469 --- /dev/null +++ b/test/flight_pyarrow_server.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#!/usr/bin/env python3 + +import json +import signal +import sys + +import pyarrow as pa +import pyarrow.flight as fl + + +def normalize_component(value): + if isinstance(value, bytes): + return value.decode("utf-8") + return str(value) + + +def descriptor_key(descriptor): + if descriptor.descriptor_type != fl.DescriptorType.PATH: + raise KeyError("only PATH descriptors are supported") + return tuple(normalize_component(part) for part in descriptor.path) + + +def ticket_key(ticket): + return tuple(normalize_component(part) for part in json.loads(ticket.ticket.decode("utf-8"))) + + +def key_ticket(key): + return fl.Ticket(json.dumps(list(key)).encode("utf-8")) + + +class InteropFlightServer(fl.FlightServerBase): + def __init__(self): + super().__init__(location="grpc://127.0.0.1:0") + self._datasets = { + ("interop", "download"): pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "name": pa.array(["one", "two", "three"]), + } + ) + } + + def _descriptor(self, key): + return fl.FlightDescriptor.for_path(*key) + + def _flight_info(self, key): + table = self._datasets[key] + endpoint = fl.FlightEndpoint( + key_ticket(key), + [fl.Location.for_grpc_tcp("127.0.0.1", self.port)], + ) + return fl.FlightInfo( + table.schema, + self._descriptor(key), + [endpoint], + total_records=table.num_rows, + total_bytes=table.nbytes, + ) + + def list_flights(self, context, criteria): + del context, criteria + for key in sorted(self._datasets): + yield self._flight_info(key) + + def get_flight_info(self, context, descriptor): + del context + return self._flight_info(descriptor_key(descriptor)) + + def get_schema(self, context, descriptor): + del context + return fl.SchemaResult(self._datasets[descriptor_key(descriptor)].schema) + + def do_get(self, context, ticket): + del context + table = self._datasets[ticket_key(ticket)] + return fl.GeneratorStream(table.schema, iter(table.to_batches(max_chunksize=2))) + + def do_put(self, context, descriptor, reader, writer): + del context + self._datasets[descriptor_key(descriptor)] = reader.read_all() + writer.write(b"stored") + + def do_exchange(self, context, descriptor, reader, writer): + del context + key = descriptor_key(descriptor) + if key != ("interop", "exchange"): + raise KeyError(f"unsupported exchange descriptor: {key}") + + writer.begin(reader.schema) + batch_index = 0 + while True: + try: + chunk = reader.read_chunk() + except StopIteration: + break + if chunk.data is None: + continue + metadata = chunk.app_metadata + if metadata is None: + metadata = pa.py_buffer(f"exchange:{batch_index}".encode("utf-8")) + writer.write_with_metadata(chunk.data, metadata) + batch_index += 1 + + def list_actions(self, context): + del context + return [("ping", "Return a fixed pong payload")] + + def do_action(self, context, action): + del context + if action.type != "ping": + raise KeyError(f"unsupported action: {action.type}") + return [b"pong"] + + +def main(): + server = InteropFlightServer() + + def shutdown_handler(signum, frame): + del signum, frame + server.shutdown() + + signal.signal(signal.SIGTERM, shutdown_handler) + signal.signal(signal.SIGINT, shutdown_handler) + + print(server.port, flush=True) + server.serve() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/flight_tls_server.py b/test/flight_tls_server.py new file mode 100644 index 0000000..c86de2b --- /dev/null +++ b/test/flight_tls_server.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#!/usr/bin/env python3 + +import json +import os +import signal + +import pyarrow as pa +import pyarrow.flight as fl + + +def normalize_component(value): + if isinstance(value, bytes): + return value.decode("utf-8") + return str(value) + + +def descriptor_key(descriptor): + if descriptor.descriptor_type != fl.DescriptorType.PATH: + raise KeyError("only PATH descriptors are supported") + return tuple(normalize_component(part) for part in descriptor.path) + + +def key_ticket(key): + return fl.Ticket(json.dumps(list(key)).encode("utf-8")) + + +class TLSInteropFlightServer(fl.FlightServerBase): + def __init__(self, cert_path, key_path): + cert = open(cert_path, "rb").read() + key = open(key_path, "rb").read() + super().__init__( + location="grpc+tls://127.0.0.1:0", + tls_certificates=[fl.CertKeyPair(cert=cert, key=key)], + ) + self._datasets = { + ("interop", "tls", "download"): pa.table( + { + "id": pa.array([31, 32, 33], type=pa.int64()), + "name": pa.array(["thirty-one", "thirty-two", "thirty-three"]), + } + ) + } + + def _descriptor(self, key): + return fl.FlightDescriptor.for_path(*key) + + def _flight_info(self, key): + table = self._datasets[key] + endpoint = fl.FlightEndpoint( + key_ticket(key), + [fl.Location.for_grpc_tls("localhost", self.port)], + ) + return fl.FlightInfo( + table.schema, + self._descriptor(key), + [endpoint], + total_records=table.num_rows, + total_bytes=table.nbytes, + ) + + def get_flight_info(self, context, descriptor): + del context + return self._flight_info(descriptor_key(descriptor)) + + def get_schema(self, context, descriptor): + del context + return fl.SchemaResult(self._datasets[descriptor_key(descriptor)].schema) + + def do_get(self, context, ticket): + del context + key = tuple(normalize_component(part) for part in json.loads(ticket.ticket.decode("utf-8"))) + table = self._datasets[key] + return fl.GeneratorStream(table.schema, iter(table.to_batches(max_chunksize=2))) + + +def main(): + cert_path = os.environ["ARROW_FLIGHT_TLS_CERT"] + key_path = os.environ["ARROW_FLIGHT_TLS_KEY"] + server = TLSInteropFlightServer(cert_path, key_path) + + def shutdown_handler(signum, frame): + del signum, frame + server.shutdown() + + signal.signal(signal.SIGTERM, shutdown_handler) + signal.signal(signal.SIGINT, shutdown_handler) + + print(server.port, flush=True) + server.serve() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/run_end_encoded_small.arrow b/test/run_end_encoded_small.arrow new file mode 100644 index 0000000..17155c3 Binary files /dev/null and b/test/run_end_encoded_small.arrow differ diff --git a/test/runtests.jl b/test/runtests.jl index 315d1b6..166e429 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,7 @@ using CategoricalArrays using DataAPI using FilePathsBase using DataFrames +using JSON3 using OffsetArrays import Random: randstring using TestSetExtensions: ExtendedTestSet @@ -39,6 +40,7 @@ include(joinpath(@__DIR__, "testtables.jl")) include(joinpath(@__DIR__, "testappend.jl")) include(joinpath(@__DIR__, "integrationtest.jl")) include(joinpath(@__DIR__, "dates.jl")) +include(joinpath(@__DIR__, "flight.jl")) struct CustomStruct x::Int @@ -50,6 +52,14 @@ struct CustomStruct2{sym} x::Int end +module EnumRoundtripModule +@enum RankingStrategy lexical=1 semantic=2 hybrid=3 +end + +module WideEnumRoundtripModule +@enum WideRanking::UInt64 tiny=1 colossal=0xffffffffffffffff +end + @testset ExtendedTestSet "Arrow" begin @testset "table roundtrips" begin for case in testtables @@ -238,6 +248,26 @@ end @test Arrow.getmetadata(tt.col2)["colkey1"] == "colvalue1" @test Arrow.getmetadata(tt.col2)["colkey2"] == "colvalue2" @test Arrow.getmetadata(tt.col3)["colkey3"] == "colvalue3" + + source = Arrow.withmetadata( + (col1=collect(1:3), col2=["a", "b", "c"]); + metadata=["source" => "base"], + colmetadata=Dict(:col1 => ["semantic.role" => "left"]), + ) + overlay = Arrow.withmetadata( + source; + metadata=["overlay" => "yes"], + colmetadata=Dict( + :col1 => ["unit" => "count"], + :col2 => ["semantic.role" => "right"], + ), + ) + overlay_tt = Arrow.Table(Arrow.tobuffer(overlay)) + @test Arrow.getmetadata(overlay_tt)["source"] == "base" + @test Arrow.getmetadata(overlay_tt)["overlay"] == "yes" + @test Arrow.getmetadata(overlay_tt.col1)["semantic.role"] == "left" + @test Arrow.getmetadata(overlay_tt.col1)["unit"] == "count" + @test Arrow.getmetadata(overlay_tt.col2)["semantic.role"] == "right" end @testset "# custom compressors" begin @@ -263,6 +293,123 @@ end @test all(isequal.(values(t), values(tt))) end + @testset "View buffer count inference" begin + inline_len = Int32(Arrow.VIEW_INLINE_BYTES) + views = Arrow.ViewElement[ + Arrow.ViewElement(inline_len, Int32(0), Int32(0), Int32(0)), + Arrow.ViewElement(inline_len + Int32(148), Int32(0), Int32(0), Int32(0)), + Arrow.ViewElement(inline_len + Int32(207), Int32(0), Int32(1), Int32(160)), + ] + validity = Arrow.ValidityBitmap(UInt8[], 1, length(views), 0) + @test Arrow._viewisinline(inline_len) + @test !Arrow._viewisinline(inline_len + Int32(1)) + @test Arrow._viewbuffercount(validity, views, Int32(0)) == 2 + @test Arrow._viewbuffercount(validity, views, Int32(1)) == 2 + @test Arrow._viewbuffercount(validity, views, Int32(3)) == 3 + + sparse_validity = Arrow.ValidityBitmap(UInt8[0x05], 1, 3, 1) + sparse_views = Arrow.ViewElement[ + Arrow.ViewElement(inline_len + Int32(64), Int32(0), Int32(0), Int32(0)), + Arrow.ViewElement(inline_len + Int32(64), Int32(0), Int32(99), Int32(0)), + Arrow.ViewElement(inline_len, Int32(0), Int32(0), Int32(0)), + ] + @test !sparse_validity[2] + @test Arrow._viewbuffercount(sparse_validity, sparse_views, Int32(0)) == 1 + end + + @testset "single-partition tobuffer byte equivalence" begin + t = (col=OffsetArray(["a", "bc", "def"], 0:2),) + io = IOBuffer() + Arrow.write(io, t) + seekstart(io) + @test read(Arrow.tobuffer(t)) == read(io) + + tm = (col=OffsetArray(Union{Missing,String}["a", missing, "def"], 0:2),) + io = IOBuffer() + Arrow.write(io, tm) + seekstart(io) + @test read(Arrow.tobuffer(tm)) == read(io) + + bt = + (col=OffsetArray([codeunits("a"), codeunits("bc"), codeunits("def")], 0:2),) + io = IOBuffer() + Arrow.write(io, bt) + seekstart(io) + @test read(Arrow.tobuffer(bt)) == read(io) + + btm = ( + col=OffsetArray( + Union{Missing,Base.CodeUnits{UInt8,String}}[ + codeunits("a"), + missing, + codeunits("def"), + ], + 0:2, + ), + ) + io = IOBuffer() + Arrow.write(io, btm) + seekstart(io) + @test read(Arrow.tobuffer(btm)) == read(io) + + mapt = ( + col=OffsetArray([Dict("a" => 1, "b" => 2), Dict("a" => 3, "b" => 4)], 0:1), + ) + io = IOBuffer() + Arrow.write(io, mapt) + seekstart(io) + @test read(Arrow.tobuffer(mapt)) == read(io) + + nestedt = (col=OffsetArray([Int64[1, 2], Int64[3, 4], Int64[]], 0:2),) + io = IOBuffer() + Arrow.write(io, nestedt) + seekstart(io) + @test read(Arrow.tobuffer(nestedt)) == read(io) + + pooled = (col=PooledArray(["a", "b", "a", "c"]),) + io = IOBuffer() + Arrow.write(io, pooled; dictencode=true) + seekstart(io) + @test read(Arrow.tobuffer(pooled; dictencode=true)) == read(io) + + meta = Dict("key1" => "value1") + colmeta = Dict(:col => Dict("colkey1" => "colvalue1")) + io = IOBuffer() + Arrow.write(io, t; metadata=meta, colmetadata=colmeta) + seekstart(io) + @test read(Arrow.tobuffer(t; metadata=meta, colmetadata=colmeta)) == read(io) + + parts = Tables.partitioner([t, t]) + io = IOBuffer() + Arrow.write(io, parts) + seekstart(io) + @test read(Arrow.tobuffer(parts)) == read(io) + + string_missing_parts = Tables.partitioner([tm, tm]) + io = IOBuffer() + Arrow.write(io, string_missing_parts) + seekstart(io) + @test read(Arrow.tobuffer(string_missing_parts)) == read(io) + + binary_parts = Tables.partitioner([bt, bt]) + io = IOBuffer() + Arrow.write(io, binary_parts) + seekstart(io) + @test read(Arrow.tobuffer(binary_parts)) == read(io) + + binary_missing_parts = Tables.partitioner([btm, btm]) + io = IOBuffer() + Arrow.write(io, binary_missing_parts) + seekstart(io) + @test read(Arrow.tobuffer(binary_missing_parts)) == read(io) + + map_parts = Tables.partitioner([mapt, mapt]) + io = IOBuffer() + Arrow.write(io, map_parts) + seekstart(io) + @test read(Arrow.tobuffer(map_parts)) == read(io) + end + @testset "# 53" begin s = "a"^100 t = (a=[SubString(s, 1:10), SubString(s, 11:20)],) @@ -294,6 +441,38 @@ end @test isequal(tt.a, ['a', missing]) end + @testset "# offset bool write paths" begin + t = ( + a=OffsetArray(Bool[true, false, true], -1:1), + b=OffsetArray(Union{Missing,Bool}[true, missing, false], -1:1), + c=OffsetArray(Any[true, false, true], -1:1), + d=OffsetArray(Any[true, missing, false], -1:1), + ) + tt = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tt.c) == Bool + @test eltype(tt.d) == Union{Missing,Bool} + @test tt.a == Bool[true, false, true] + @test isequal(tt.b, Union{Missing,Bool}[true, missing, false]) + @test tt.c == Bool[true, false, true] + @test isequal(tt.d, Union{Missing,Bool}[true, missing, false]) + end + + @testset "# offset primitive write paths" begin + t = ( + a=OffsetArray(Int64[1, 2, 3], -1:1), + b=OffsetArray(Union{Missing,Int64}[1, missing, 3], -1:1), + c=OffsetArray(Any[1, 2, 3], -1:1), + d=OffsetArray(Any[1, missing, 3], -1:1), + ) + tt = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tt.c) == Int64 + @test eltype(tt.d) == Union{Missing,Int64} + @test tt.a == Int64[1, 2, 3] + @test isequal(tt.b, Union{Missing,Int64}[1, missing, 3]) + @test tt.c == Int64[1, 2, 3] + @test isequal(tt.d, Union{Missing,Int64}[1, missing, 3]) + end + @testset "# automatic custom struct serialization/deserialization" begin t = (col1=[CustomStruct(1, 2.3, "hey"), CustomStruct(4, 5.6, "there")],) @@ -306,6 +485,86 @@ end @test all(isequal.(values(t), values(tt))) end + @testset "# Julia Enum extension logical type roundtrip" begin + t = ( + col1=[EnumRoundtripModule.lexical, EnumRoundtripModule.hybrid], + col2=Union{Missing,EnumRoundtripModule.RankingStrategy}[ + missing, + EnumRoundtripModule.semantic, + ], + ) + + bytes = read(Arrow.tobuffer(t)) + tt = Arrow.Table(IOBuffer(bytes)) + raw = Arrow.Table(IOBuffer(bytes); convert=false) + + @test length(tt) == length(t) + @test eltype(tt.col1) == EnumRoundtripModule.RankingStrategy + @test eltype(tt.col2) == Union{Missing,EnumRoundtripModule.RankingStrategy} + @test tt.col1 == [EnumRoundtripModule.lexical, EnumRoundtripModule.hybrid] + @test isequal( + tt.col2, + Union{Missing,EnumRoundtripModule.RankingStrategy}[ + missing, + EnumRoundtripModule.semantic, + ], + ) + @test eltype(raw.col1) == Int32 + @test eltype(raw.col2) == Union{Missing,Int32} + @test raw.col1 == Int32[1, 3] + @test isequal(raw.col2, Union{Missing,Int32}[missing, 2]) + @test Arrow.getmetadata(tt.col1)["ARROW:extension:name"] == "JuliaLang.Enum" + @test occursin( + "Main.EnumRoundtripModule.RankingStrategy", + Arrow.getmetadata(tt.col1)["ARROW:extension:metadata"], + ) + end + + @testset "# Julia Enum extension contract edge cases" begin + t = ( + col=[WideEnumRoundtripModule.tiny, WideEnumRoundtripModule.colossal], + nullable=Union{Missing,WideEnumRoundtripModule.WideRanking}[ + missing, + WideEnumRoundtripModule.colossal, + ], + ) + bytes = read(Arrow.tobuffer(t)) + tt = Arrow.Table(IOBuffer(bytes)) + raw = Arrow.Table(IOBuffer(bytes); convert=false) + + @test eltype(tt.col) == WideEnumRoundtripModule.WideRanking + @test eltype(tt.nullable) == Union{Missing,WideEnumRoundtripModule.WideRanking} + @test tt.col == [WideEnumRoundtripModule.tiny, WideEnumRoundtripModule.colossal] + @test isequal( + tt.nullable, + Union{Missing,WideEnumRoundtripModule.WideRanking}[ + missing, + WideEnumRoundtripModule.colossal, + ], + ) + @test eltype(raw.col) == UInt64 + @test eltype(raw.nullable) == Union{Missing,UInt64} + @test raw.col == UInt64[1, typemax(UInt64)] + @test isequal(raw.nullable, Union{Missing,UInt64}[missing, typemax(UInt64)]) + + mismatch_metadata = "type=Main.WideEnumRoundtripModule.WideRanking;labels=tiny:1,colossal:2" + @test_logs (:warn, r"unsupported ARROW:extension:name type: \"JuliaLang.Enum\"") begin + mismatch_tt = Arrow.Table( + Arrow.tobuffer( + (col=UInt64[1, typemax(UInt64)],); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "JuliaLang.Enum", + "ARROW:extension:metadata" => mismatch_metadata, + ), + ), + ), + ) + @test eltype(mismatch_tt.col) == UInt64 + @test copy(mismatch_tt.col) == UInt64[1, typemax(UInt64)] + end + end + @testset "# 76" begin t = (col1=NamedTuple{(:a,),Tuple{Union{Int,String}}}[(a=1,), (a="x",)],) tt = Arrow.Table(Arrow.tobuffer(t)) @@ -328,6 +587,67 @@ end @test copy(tt.a) isa Vector{Nanosecond} @test copy(tt.b) isa Vector{UUID} @test copy(tt.c) isa Vector{Union{Missing,Nanosecond}} + @test Arrow.getmetadata(tt.b)["ARROW:extension:name"] == "arrow.uuid" + + legacy = ( + b=[ + Arrow.ArrowTypes.toarrow(UUID("550e8400-e29b-41d4-a716-446655440000")), + Arrow.ArrowTypes.toarrow(UUID("550e8400-e29b-41d4-a716-446655440001")), + ], + ) + legacy_tt = Arrow.Table( + Arrow.tobuffer( + legacy; + colmetadata=Dict( + :b => Dict("ARROW:extension:name" => "JuliaLang.UUID"), + ), + ), + ) + @test copy(legacy_tt.b) == [ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ] + + toffset = ( + b=OffsetArray( + [ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + -1:0, + ), + bm=OffsetArray( + Union{Missing,UUID}[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + missing, + ], + -1:0, + ), + ba=OffsetArray( + Any[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + -1:0, + ), + bam=OffsetArray( + Any[UUID("550e8400-e29b-41d4-a716-446655440000"), missing], + -1:0, + ), + ) + ttoffset = Arrow.Table(Arrow.tobuffer(toffset)) + @test collect(toffset.b) == ttoffset.b + @test isequal(collect(toffset.bm), ttoffset.bm) + @test eltype(ttoffset.ba) == NTuple{16,UInt8} + @test eltype(ttoffset.bam) == Union{Missing,NTuple{16,UInt8}} + @test map(Arrow.ArrowTypes.toarrow, collect(toffset.ba)) == copy(ttoffset.ba) + @test isequal( + map( + x -> ismissing(x) ? missing : Arrow.ArrowTypes.toarrow(x), + collect(toffset.bam), + ), + copy(ttoffset.bam), + ) end @testset "# copy on DictEncoding w/ missing values" begin @@ -353,6 +673,13 @@ end @test isa(first(av.indices), Signed) @test length(av) == 3 @test eltype(av) == String + + x = CategoricalArray(Union{Missing,String}["a", missing, "ccc"]) + tt = Arrow.Table(Arrow.tobuffer((x=x,); dictencode=true)) + @test isequal(collect(tt.x), collect(x)) + @test isequal(collect(copy(tt.x)), collect(x)) + df = DataFrame(tt; copycols=true) + @test isequal(collect(df.x), collect(x)) end @testset "# 120" begin @@ -463,6 +790,502 @@ end ) end + @testset "canonical timestamp_with_offset" begin + values = + Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}}[ + Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 1577836800000, + ), + 330, + ), + missing, + Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 1577923200000, + ), + -480, + ), + ] + @test ArrowTypes.JuliaType( + Val(Symbol("arrow.timestamp_with_offset")), + NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + }, + "", + ) == Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND} + tt = Arrow.Table(Arrow.tobuffer((col=values,))) + @test eltype(tt.col) == + Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}} + @test isequal(copy(tt.col), values) + @test Arrow.getmetadata(tt.col)["ARROW:extension:name"] == + "arrow.timestamp_with_offset" + + raw_tt = Arrow.Table(Arrow.tobuffer((col=values,)); convert=false) + @test eltype(raw_tt.col) == Union{ + Missing, + NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + }, + } + @test isequal( + copy(raw_tt.col), + Union{ + Missing, + NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + }, + }[ + ( + timestamp=Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 1577836800000, + ), + offset_minutes=Int16(330), + ), + missing, + ( + timestamp=Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 1577923200000, + ), + offset_minutes=Int16(-480), + ), + ], + ) + end + + @testset "Run-End Encoded read support" begin + path = joinpath(@__DIR__, "run_end_encoded_small.arrow") + expected = ["a", "a", "b", "b", "b"] + + tt = Arrow.Table(path) + @test tt isa Arrow.Table + @test eltype(tt.x) == Union{Missing,String} + @test collect(tt.x) == expected + @test copy(tt.x) == expected + + batches = collect(Arrow.Stream(path)) + @test length(batches) == 1 + @test collect(batches[1].x) == expected + + @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.tobuffer(tt) + @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.tobuffer(( + x=tt.x, + )) + end + + @testset "canonical bool8/json/opaque" begin + bools = + Union{Missing,Arrow.Bool8}[Arrow.Bool8(true), missing, Arrow.Bool8(false)] + @test ArrowTypes.JuliaType(Val(Symbol("arrow.bool8")), Int8, "") == Arrow.Bool8 + tt = Arrow.Table(Arrow.tobuffer((col=bools,))) + @test eltype(tt.col) == Union{Missing,Arrow.Bool8} + @test isequal(copy(tt.col), bools) + @test Arrow.getmetadata(tt.col)["ARROW:extension:name"] == "arrow.bool8" + + raw_tt = Arrow.Table(Arrow.tobuffer((col=bools,)); convert=false) + @test eltype(raw_tt.col) == Union{Missing,Int8} + @test isequal(copy(raw_tt.col), Union{Missing,Int8}[1, missing, 0]) + + jsons = Union{Missing,Arrow.JSONText{String}}[ + Arrow.JSONText("{\"a\":1}"), + missing, + Arrow.JSONText("[1,2,3]"), + ] + @test ArrowTypes.JuliaType(Val(Symbol("arrow.json")), String, "") == + Arrow.JSONText{String} + json_tt = Arrow.Table(Arrow.tobuffer((col=jsons,))) + @test eltype(json_tt.col) == Union{Missing,Arrow.JSONText{String}} + @test isequal(copy(json_tt.col), jsons) + @test Arrow.getmetadata(json_tt.col)["ARROW:extension:name"] == "arrow.json" + + raw_json_tt = Arrow.Table(Arrow.tobuffer((col=jsons,)); convert=false) + @test eltype(raw_json_tt.col) == Union{Missing,String} + @test isequal( + copy(raw_json_tt.col), + Union{Missing,String}["{\"a\":1}", missing, "[1,2,3]"], + ) + + opaque_meta = Arrow.opaquemetadata("pkg.Type", "vendor.example") + @test ArrowTypes.JuliaType(Val(Symbol("arrow.opaque")), String, opaque_meta) == + String + opaque_tt = Arrow.Table( + Arrow.tobuffer( + (col=["a", "b"],); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.opaque", + "ARROW:extension:metadata" => opaque_meta, + ), + ), + ), + ) + @test eltype(opaque_tt.col) == String + @test copy(opaque_tt.col) == ["a", "b"] + @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:name"] == "arrow.opaque" + @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:metadata"] == + opaque_meta + end + + @testset "canonical advanced passthrough" begin + function assert_canonical_extension_error(f::Function, needle::AbstractString) + err = try + f() + nothing + catch e + e + end + @test err !== nothing + @test occursin(needle, sprint(showerror, err)) + return + end + + @test Arrow.variantmetadata() == "" + + fixed_metadata = Arrow.fixedshapetensormetadata( + [2, 2]; + dim_names=["x", "y"], + permutation=[1, 0], + ) + @test JSON3.read(fixed_metadata)["shape"] == [2, 2] + @test JSON3.read(fixed_metadata)["dim_names"] == ["x", "y"] + @test JSON3.read(fixed_metadata)["permutation"] == [1, 0] + + variable_metadata = Arrow.variableshapetensormetadata( + uniform_shape=Union{Nothing,Int}[2], + dim_names=["axis0"], + permutation=[0], + ) + @test ArrowTypes.JuliaType(Val(Symbol("arrow.parquet.variant")), String, "") == + String + @test ArrowTypes.JuliaType( + Val(Symbol("arrow.fixed_shape_tensor")), + NTuple{4,Int32}, + fixed_metadata, + ) == NTuple{4,Int32} + @test ArrowTypes.JuliaType( + Val(Symbol("arrow.variable_shape_tensor")), + NamedTuple{(:data, :shape),Tuple{Vector{Int32},NTuple{1,Int32}}}, + variable_metadata, + ) == NamedTuple{(:data, :shape),Tuple{Vector{Int32},NTuple{1,Int32}}} + @test JSON3.read(variable_metadata)["uniform_shape"] == [2] + @test JSON3.read(variable_metadata)["dim_names"] == ["axis0"] + @test JSON3.read(variable_metadata)["permutation"] == [0] + @test Arrow.variableshapetensormetadata() == "" + + @test_throws ArgumentError Arrow.fixedshapetensormetadata( + [2, 2]; + dim_names=["x"], + ) + @test_throws ArgumentError Arrow.variableshapetensormetadata( + uniform_shape=Union{Nothing,Int}[2, nothing]; + permutation=[0], + ) + + variant_values = + Union{Missing,NamedTuple{(:metadata, :value),Tuple{String,String}}}[ + (metadata="json", value="{\"a\":1}"), + missing, + (metadata="str", value="abc"), + ] + @test_logs min_level=Base.CoreLogging.Warn begin + variant_tt = Arrow.Table( + Arrow.tobuffer( + (col=variant_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.parquet.variant", + "ARROW:extension:metadata" => Arrow.variantmetadata(), + ), + ), + ), + ) + @test eltype(variant_tt.col) == eltype(variant_values) + @test isequal(copy(variant_tt.col), variant_values) + @test Arrow.getmetadata(variant_tt.col)["ARROW:extension:name"] == + "arrow.parquet.variant" + end + + fixed_tensor_values = Union{Missing,NTuple{4,Int32}}[ + (Int32(1), Int32(2), Int32(3), Int32(4)), + missing, + (Int32(5), Int32(6), Int32(7), Int32(8)), + ] + @test_logs min_level=Base.CoreLogging.Warn begin + fixed_tensor_tt = Arrow.Table( + Arrow.tobuffer( + (col=fixed_tensor_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.fixed_shape_tensor", + "ARROW:extension:metadata" => fixed_metadata, + ), + ), + ), + ) + @test eltype(fixed_tensor_tt.col) == eltype(fixed_tensor_values) + @test isequal(copy(fixed_tensor_tt.col), fixed_tensor_values) + @test Arrow.getmetadata(fixed_tensor_tt.col)["ARROW:extension:name"] == + "arrow.fixed_shape_tensor" + end + + variable_tensor_values = Union{ + Missing, + NamedTuple{(:data, :shape),Tuple{Vector{Int32},NTuple{1,Int32}}}, + }[ + (data=Int32[1, 2, 3, 4], shape=(Int32(2),)), + missing, + (data=Int32[5, 6], shape=(Int32(1),)), + ] + @test_logs min_level=Base.CoreLogging.Warn begin + variable_tensor_tt = Arrow.Table( + Arrow.tobuffer( + (col=variable_tensor_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.variable_shape_tensor", + "ARROW:extension:metadata" => variable_metadata, + ), + ), + ), + ) + @test eltype(variable_tensor_tt.col) == eltype(variable_tensor_values) + @test isequal( + map( + x -> x === missing ? missing : (data=copy(x.data), shape=x.shape), + copy(variable_tensor_tt.col), + ), + variable_tensor_values, + ) + @test Arrow.getmetadata(variable_tensor_tt.col)["ARROW:extension:name"] == + "arrow.variable_shape_tensor" + end + + invalid_variant_bytes = Arrow.tobuffer( + (col=variant_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.parquet.variant", + "ARROW:extension:metadata" => "{\"unexpected\":true}", + ), + ), + ) + assert_canonical_extension_error( + () -> Arrow.Table(invalid_variant_bytes), + "invalid canonical arrow.parquet.variant extension", + ) + + invalid_fixed_bytes = Arrow.tobuffer( + (col=fixed_tensor_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.fixed_shape_tensor", + "ARROW:extension:metadata" => + Arrow.fixedshapetensormetadata([3, 2]), + ), + ), + ) + assert_canonical_extension_error( + () -> Arrow.Table(invalid_fixed_bytes), + "invalid canonical arrow.fixed_shape_tensor extension", + ) + + invalid_variable_bytes = Arrow.tobuffer( + (col=["a", "b"],); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.variable_shape_tensor", + "ARROW:extension:metadata" => + Arrow.variableshapetensormetadata( + uniform_shape=Union{Nothing,Int}[1], + ), + ), + ), + ) + assert_canonical_extension_error( + () -> Arrow.Table(invalid_variable_bytes), + "invalid canonical arrow.variable_shape_tensor extension", + ) + end + + @testset "logical extension runtime contract" begin + uuid = UUID("550e8400-e29b-41d4-a716-446655440000") + @test Arrow._builtinarrowtype(UUID) == NTuple{16,UInt8} + @test Arrow._builtintoarrow(uuid) == + ArrowTypes._cast(NTuple{16,UInt8}, uuid.value) + @test Arrow._builtinarrowname(UUID) == Symbol("arrow.uuid") + @test ArrowTypes.ArrowType(UUID) == Arrow._builtinarrowtype(UUID) + @test ArrowTypes.toarrow(uuid) == Arrow._builtintoarrow(uuid) + @test ArrowTypes.arrowname(UUID) == Arrow._builtinarrowname(UUID) + @test ArrowTypes.JuliaType(Val(Symbol("arrow.uuid"))) == UUID + @test ArrowTypes.JuliaType(Val(Symbol("JuliaLang.UUID"))) == UUID + uuid_spec = Arrow._extensionspec(UUID) + @test uuid_spec isa Arrow.ExtensionTypeSpec + @test uuid_spec.name == Arrow.ArrowTypes.UUIDSYMBOL + @test uuid_spec.metadata == "" + @test Arrow._resolveextensionjuliatype( + Arrow.ExtensionTypeSpec(Arrow.ArrowTypes.LEGACY_UUIDSYMBOL, ""), + NTuple{16,UInt8}, + ) == UUID + + bool8_spec = Arrow._extensionspec(Arrow.Bool8) + @test bool8_spec isa Arrow.ExtensionTypeSpec + @test bool8_spec.name == Symbol("arrow.bool8") + @test Arrow._builtinarrowtype(Arrow.Bool8) == Int8 + @test Arrow._builtintoarrow(Arrow.Bool8(true)) == Int8(1) + @test Arrow._builtinarrowname(Arrow.Bool8) == Symbol("arrow.bool8") + @test Arrow._builtinfromarrow(Arrow.Bool8, Int8(1)) == Arrow.Bool8(true) + @test Arrow._builtindefault(Arrow.Bool8) == Arrow.Bool8(false) + @test Arrow._resolveextensionjuliatype(bool8_spec, Int8) == Arrow.Bool8 + + @test Arrow._builtinarrowtype(Arrow.JSONText{String}) == String + @test Arrow._builtintoarrow(Arrow.JSONText("abc")) == "abc" + @test Arrow._builtinarrowname(Arrow.JSONText{String}) == Symbol("arrow.json") + @test Arrow._builtinfromarrow(Arrow.JSONText{String}, pointer("abc"), 3) == + Arrow.JSONText("abc") + @test Arrow._builtinfromarrow(Arrow.JSONText{String}, "xyz") == + Arrow.JSONText("xyz") + @test Arrow._builtindefault(Arrow.JSONText{String}) == Arrow.JSONText("") + + timestamp_storage = NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + } + zdt = ZonedDateTime(Dates.DateTime(2020), tz"Europe/Paris") + @test Arrow._builtinarrowtype(ZonedDateTime) == Arrow.Timestamp + @test Arrow._builtintoarrow(zdt) == convert( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,Symbol("Europe/Paris")}, + zdt, + ) + @test Arrow._builtinarrowname(ZonedDateTime) == + Symbol("JuliaLang.ZonedDateTime-UTC") + paris_timestamp = + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,Symbol("Europe/Paris")}(0) + @test Arrow._builtinfromarrow(ZonedDateTime, paris_timestamp) == + convert(ZonedDateTime, paris_timestamp) + @test Arrow._builtindefault(ZonedDateTime) == + ZonedDateTime(1, 1, 1, 1, 1, 1, tz"UTC") + @test Arrow._builtinarrowname( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == Symbol("arrow.timestamp_with_offset") + @test Arrow._builtinarrowtype( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + } + ts_with_offset = Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), + Int16(-480), + ) + @test Arrow._builtintoarrow(ts_with_offset) == ( + timestamp=Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), + offset_minutes=Int16(-480), + ) + @test ArrowTypes.ArrowType( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == Arrow._builtinarrowtype( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) + @test ArrowTypes.toarrow(ts_with_offset) == + Arrow._builtintoarrow(ts_with_offset) + @test Arrow._builtindefault( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == zero(Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}) + @test Arrow._builtinfromarrowstruct( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + Val((:timestamp, :offset_minutes)), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), + Int16(-480), + ) == ts_with_offset + @test Arrow._resolveextensionjuliatype( + Arrow.ExtensionTypeSpec(Symbol("arrow.timestamp_with_offset"), ""), + timestamp_storage, + ) == Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND} + + opaque_spec = Arrow.ExtensionTypeSpec( + Symbol("arrow.opaque"), + Arrow.opaquemetadata("demo.type", "demo.vendor"), + ) + @test Arrow.opaquemetadata("demo.type", "demo.vendor") == + Arrow._builtinopaquemetadata("demo.type", "demo.vendor") + @test Arrow._resolveextensionjuliatype(opaque_spec, Vector{UInt8}) == + Vector{UInt8} + @test Arrow.variantmetadata() == Arrow._builtinvariantmetadata() + @test Arrow.fixedshapetensormetadata( + [2, 2]; + dim_names=["row", "col"], + permutation=[1, 0], + ) == Arrow._builtinfixedshapetensormetadata( + [2, 2]; + dim_names=["row", "col"], + permutation=[1, 0], + ) + @test Arrow.variableshapetensormetadata( + uniform_shape=[2, nothing]; + dim_names=["row", "col"], + permutation=[1, 0], + ) == Arrow._builtinvariableshapetensormetadata( + uniform_shape=[2, nothing]; + dim_names=["row", "col"], + permutation=[1, 0], + ) + @test Arrow._builtinextensionjuliatype( + Val(Symbol("JuliaLang.ZonedDateTime-UTC")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == ZonedDateTime + @test ArrowTypes.JuliaType( + Val(Symbol("JuliaLang.ZonedDateTime-UTC")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == ZonedDateTime + @test Arrow._builtinextensionjuliatype( + Val(Symbol("JuliaLang.ZonedDateTime")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == Arrow.LocalZonedDateTime + @test ArrowTypes.JuliaType( + Val(Symbol("JuliaLang.ZonedDateTime")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == Arrow.LocalZonedDateTime + local_zdt_timestamp = + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,Symbol("Europe/Paris")}(0) + @test Arrow._builtinfromarrow(Arrow.LocalZonedDateTime, local_zdt_timestamp) == + ArrowTypes.fromarrow(Arrow.LocalZonedDateTime, local_zdt_timestamp) + + @test Arrow._resolveextensionjuliatype( + Arrow.ExtensionTypeSpec(Symbol("JuliaLang.ZonedDateTime"), ""), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == Arrow.LocalZonedDateTime + end + + @testset "tensor message boundary" begin + function patch_message_header_type(bytes, header_type::UInt8) + patched = copy(bytes) + msg = Arrow.FlatBuffers.getrootas(Arrow.Meta.Message, patched, 8) + offset = Arrow.FlatBuffers.offset(msg, 6) + @test offset != 0 + patched[Arrow.FlatBuffers.pos(msg) + offset + 1] = header_type + return patched + end + + base = take!(Arrow.tobuffer((x=[1, 2],))) + + tensor_bytes = patch_message_header_type(base, UInt8(4)) + @test_throws ArgumentError(Arrow.TENSOR_UNSUPPORTED) Arrow.Table(tensor_bytes) + @test_throws ArgumentError(Arrow.TENSOR_UNSUPPORTED) collect( + Arrow.Stream(tensor_bytes), + ) + + sparse_tensor_bytes = patch_message_header_type(base, UInt8(5)) + @test_throws ArgumentError(Arrow.SPARSE_TENSOR_UNSUPPORTED) Arrow.Table( + sparse_tensor_bytes, + ) + @test_throws ArgumentError(Arrow.SPARSE_TENSOR_UNSUPPORTED) collect( + Arrow.Stream(sparse_tensor_bytes), + ) + end + @testset "# 158" begin # arrow ipc stream generated from pyarrow with no record batches bytes = UInt8[ @@ -624,6 +1447,19 @@ end t = (col1=[["boop", "she"], ["boop", "she"], ["boo"]],) tbl = Arrow.Table(Arrow.tobuffer(t)) @test eltype(tbl.col1) <: AbstractVector{String} + + toffset = ( + col1=OffsetArray([Int64[1, 2], Int64[3, 4], Int64[]], -1:1), + col2=OffsetArray( + Union{Missing,Vector{Int64}}[Int64[1], missing, Int64[2, 3]], + -1:1, + ), + ) + tt = Arrow.Table(Arrow.tobuffer(toffset)) + @test eltype(tt.col1) <: AbstractVector{Int64} + @test Base.nonmissingtype(eltype(tt.col2)) <: AbstractVector{Int64} + @test collect(toffset.col1) == tt.col1 + @test isequal(collect(toffset.col2), tt.col2) end @testset "# 200 VersionNumber" begin @@ -632,6 +1468,27 @@ end @test eltype(tbl.col1) == VersionNumber end + @testset "offset struct string write paths" begin + rows = OffsetArray( + Union{Missing,NamedTuple{(:s,),Tuple{String}}}[ + (s="a",), + missing, + (s="bc",), + ], + -1:1, + ) + tt = Arrow.Table(Arrow.tobuffer((rows=rows,))) + @test Base.nonmissingtype(eltype(tt.rows)) == NamedTuple{(:s,),Tuple{String}} + @test isequal(collect(rows), tt.rows) + end + + @testset "Complex" begin + t = (col1=Union{ComplexF64,Missing}[1 + 2im, missing, 3 + 4im],) + tbl = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tbl.col1) == Union{ComplexF64,Missing} + @test isequal(collect(tbl.col1), t.col1) + end + @testset "`show`" begin str = nothing table = (; a=1:5, b=fill(1.0, 5)) @@ -852,6 +1709,95 @@ end @test_throws ArgumentError( "`keytype(d)` must be concrete to serialize map-like `d`, but `keytype(d) == Real`", ) Arrow.tobuffer(t) + + t = ( + x=OffsetArray([Dict("a" => 1, "b" => 2), Dict("c" => 3)], -1:0), + xm=OffsetArray( + Union{Missing,Dict{String,Int}}[Dict("a" => 1), missing], + -1:0, + ), + xe=OffsetArray( + [Dict("a" => 1, "b" => 2, "c" => 3), Dict{String,Int}()], + -1:0, + ), + xem=OffsetArray( + Union{Missing,Dict{String,Int}}[Dict{String,Int}(), missing], + -1:0, + ), + xa=OffsetArray(Any[Dict("a" => 1, "b" => 2), Dict("c" => 3)], -1:0), + xam=OffsetArray(Any[Dict("a" => 1), missing], -1:0), + xame=OffsetArray(Any[Dict{String,Int}(), missing], -1:0), + ) + tt = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tt.x) == Dict{String,Int64} + @test eltype(tt.xm) == Union{Missing,Dict{String,Int64}} + @test eltype(tt.xe) == Dict{String,Int64} + @test eltype(tt.xem) == Union{Missing,Dict{String,Int64}} + @test eltype(tt.xa) == Dict{String,Int64} + @test eltype(tt.xam) == Union{Missing,Dict{String,Int64}} + @test eltype(tt.xame) == Union{Missing,Dict{String,Int64}} + @test copy(tt.x) isa Vector{Dict{String,Int64}} + @test copy(tt.xm) isa Vector{Union{Missing,Dict{String,Int64}}} + @test copy(tt.xem) isa Vector{Union{Missing,Dict{String,Int64}}} + @test copy(tt.xa) isa Vector{Dict{String,Int64}} + @test copy(tt.xam) isa Vector{Union{Missing,Dict{String,Int64}}} + @test copy(tt.xame) isa Vector{Union{Missing,Dict{String,Int64}}} + @test collect(t.x) == tt.x + @test isequal(collect(t.xm), tt.xm) + @test collect(t.xe) == tt.xe + @test isequal(collect(t.xem), tt.xem) + @test collect(t.xa) == tt.xa + @test isequal(collect(t.xam), tt.xam) + @test isequal(collect(t.xame), tt.xame) + + mapio = IOBuffer() + Arrow.write(mapio, (x=t.xm,)) + seekstart(mapio) + @test read(Arrow.tobuffer((x=t.xm,))) == read(mapio) + + mapbuf = Arrow.tobuffer((x=t.xm,)) + seekend(mapbuf) + mappos = position(mapbuf) + Arrow.append(mapbuf, Arrow.Table(Arrow.tobuffer((x=t.xm,)))) + seekstart(mapbuf) + mapbuf1 = read(mapbuf, mappos) + mapbuf2 = read(mapbuf) + mapt1 = Arrow.Table(mapbuf1) + mapt2 = Arrow.Table(mapbuf2) + @test isequal(collect(mapt1.x), collect(mapt2.x)) + + emptymapbuf = Arrow.tobuffer((x=t.xe,)) + seekend(emptymapbuf) + emptymappos = position(emptymapbuf) + Arrow.append(emptymapbuf, Arrow.Table(Arrow.tobuffer((x=t.xe,)))) + seekstart(emptymapbuf) + emptymapbuf1 = read(emptymapbuf, emptymappos) + emptymapbuf2 = read(emptymapbuf) + emptymapt1 = Arrow.Table(emptymapbuf1) + emptymapt2 = Arrow.Table(emptymapbuf2) + @test isequal(collect(emptymapt1.x), collect(emptymapt2.x)) + + anymapbuf = Arrow.tobuffer((x=t.xam,)) + seekend(anymapbuf) + anymappos = position(anymapbuf) + Arrow.append(anymapbuf, Arrow.Table(Arrow.tobuffer((x=t.xam,)))) + seekstart(anymapbuf) + anymapbuf1 = read(anymapbuf, anymappos) + anymapbuf2 = read(anymapbuf) + anymapt1 = Arrow.Table(anymapbuf1) + anymapt2 = Arrow.Table(anymapbuf2) + @test isequal(collect(anymapt1.x), collect(anymapt2.x)) + + anyemptymapbuf = Arrow.tobuffer((x=t.xame,)) + seekend(anyemptymapbuf) + anyemptymappos = position(anyemptymapbuf) + Arrow.append(anyemptymapbuf, Arrow.Table(Arrow.tobuffer((x=t.xame,)))) + seekstart(anyemptymapbuf) + anyemptymapbuf1 = read(anyemptymapbuf, anyemptymappos) + anyemptymapbuf2 = read(anyemptymapbuf) + anyemptymapt1 = Arrow.Table(anyemptymapbuf1) + anyemptymapt2 = Arrow.Table(anyemptymapbuf2) + @test isequal(collect(anyemptymapt1.x), collect(anyemptymapt2.x)) end @testset "# 214" begin @@ -966,6 +1912,45 @@ end @test isequal(t1.bm, t2.bm) @test isequal(t1.c, t2.c) @test isequal(t1.cm, t2.cm) + + toffset = ( + b=OffsetArray([b"01", b"", b"3"], -1:1), + bm=OffsetArray( + Union{Missing,Base.CodeUnits{UInt8,String}}[b"01", b"3", missing], + -1:1, + ), + ba=OffsetArray(Any[b"01", b"", b"3"], -1:1), + bam=OffsetArray(Any[b"01", missing, b"3"], -1:1), + c=OffsetArray(["a", "b", "c"], -1:1), + cm=OffsetArray(Union{Missing,String}["a", "c", missing], -1:1), + ) + ttoffset = Arrow.Table(Arrow.tobuffer(toffset)) + @test eltype(ttoffset.b) <: Base.CodeUnits + @test Base.nonmissingtype(eltype(ttoffset.bm)) <: Base.CodeUnits + @test eltype(ttoffset.ba) <: Base.CodeUnits + @test Base.nonmissingtype(eltype(ttoffset.bam)) <: Base.CodeUnits + @test eltype(ttoffset.c) == String + @test eltype(ttoffset.cm) == Union{Missing,String} + @test collect(toffset.b) == ttoffset.b + @test isequal(collect(toffset.bm), ttoffset.bm) + @test collect(toffset.ba) == copy(ttoffset.ba) + @test isequal(collect(toffset.bam), copy(ttoffset.bam)) + @test collect(toffset.c) == ttoffset.c + @test isequal(collect(toffset.cm), ttoffset.cm) + + offsetbuf = Arrow.tobuffer(toffset) + seekend(offsetbuf) + offsetpos = position(offsetbuf) + Arrow.append(offsetbuf, ttoffset) + seekstart(offsetbuf) + offsetbuf1 = read(offsetbuf, offsetpos) + offsetbuf2 = read(offsetbuf) + offsett1 = Arrow.Table(offsetbuf1) + offsett2 = Arrow.Table(offsetbuf2) + @test collect(offsett1.b) == collect(offsett2.b) + @test isequal(collect(offsett1.bm), collect(offsett2.bm)) + @test collect(offsett1.c) == collect(offsett2.c) + @test isequal(collect(offsett1.cm), collect(offsett2.cm)) end @testset "# 435" begin