diff --git a/cpp/src/arrow/util/compression.cc b/cpp/src/arrow/util/compression.cc index d4788569732..1ae71eba096 100644 --- a/cpp/src/arrow/util/compression.cc +++ b/cpp/src/arrow/util/compression.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "arrow/result.h" #include "arrow/status.h" @@ -200,11 +201,16 @@ Result> Codec::Create(Compression::type codec_type, codec = internal::MakeLz4HadoopRawCodec(); #endif break; - case Compression::ZSTD: + case Compression::ZSTD: { #ifdef ARROW_WITH_ZSTD - codec = internal::MakeZSTDCodec(compression_level); + auto opt = dynamic_cast(&codec_options); + codec = internal::MakeZSTDCodec( + compression_level, + opt ? opt->compression_context_params : std::vector>{}, + opt ? opt->decompression_context_params : std::vector>{}); #endif break; + } case Compression::BZ2: #ifdef ARROW_WITH_BZ2 codec = internal::MakeBZ2Codec(compression_level); diff --git a/cpp/src/arrow/util/compression.h b/cpp/src/arrow/util/compression.h index f7bf4d5e12d..9e19def6bb5 100644 --- a/cpp/src/arrow/util/compression.h +++ b/cpp/src/arrow/util/compression.h @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include "arrow/result.h" #include "arrow/status.h" @@ -142,6 +144,16 @@ class ARROW_EXPORT BrotliCodecOptions : public CodecOptions { std::optional window_bits; }; +// ---------------------------------------------------------------------- +// Zstd codec options implementation + +class ARROW_EXPORT ZstdCodecOptions : public CodecOptions { + public: + // Valid keys can be found at https://facebook.github.io/zstd/zstd_manual.html. + std::vector> compression_context_params; + std::vector> decompression_context_params; +}; + /// \brief Compression codec class ARROW_EXPORT Codec { public: diff --git a/cpp/src/arrow/util/compression_internal.h b/cpp/src/arrow/util/compression_internal.h index ab2cf6d98b6..963ad0c04ca 100644 --- a/cpp/src/arrow/util/compression_internal.h +++ b/cpp/src/arrow/util/compression_internal.h @@ -18,6 +18,8 @@ #pragma once #include +#include +#include #include "arrow/util/compression.h" // IWYU pragma: export @@ -74,7 +76,9 @@ std::unique_ptr MakeLz4HadoopRawCodec(); constexpr int kZSTDDefaultCompressionLevel = 1; std::unique_ptr MakeZSTDCodec( - int compression_level = kZSTDDefaultCompressionLevel); + int compression_level = kZSTDDefaultCompressionLevel, + std::vector> compression_context_params = {}, + std::vector> decompression_context_params = {}); } // namespace internal } // namespace util diff --git a/cpp/src/arrow/util/compression_test.cc b/cpp/src/arrow/util/compression_test.cc index eb2da98d511..c197a83198e 100644 --- a/cpp/src/arrow/util/compression_test.cc +++ b/cpp/src/arrow/util/compression_test.cc @@ -16,12 +16,15 @@ // under the License. #include +#include #include #include #include #include #include +#include #include +#include #include #include @@ -446,36 +449,17 @@ TEST(TestCodecMisc, SpecifyCompressionLevel) { } } -TEST(TestCodecMisc, SpecifyCodecOptionsGZip) { - // for now only GZIP & Brotli codec options supported, since it has specific parameters - // to be customized, other codecs could directly go with CodecOptions, could add more - // specific codec options if needed. - struct CombinationOption { - int level; - GZipFormat format; - int window_bits; - bool expect_success; - }; - constexpr CombinationOption combinations[] = {{2, GZipFormat::ZLIB, 12, true}, - {9, GZipFormat::GZIP, 9, true}, - {9, GZipFormat::GZIP, 20, false}, - {5, GZipFormat::DEFLATE, -12, false}, - {-992, GZipFormat::GZIP, 15, false}}; - +template T> +void CheckSpecifyCodecOptions(Compression::type compression, + std::span> options) { std::vector data = MakeRandomData(2000); - for (const auto& combination : combinations) { - const auto compression = Compression::GZIP; + for (const auto& [codec_option, expect_success] : options) { if (!Codec::IsAvailable(compression)) { // Support for this codec hasn't been built continue; } - auto codec_options = arrow::util::GZipCodecOptions(); - codec_options.compression_level = combination.level; - codec_options.gzip_format = combination.format; - codec_options.window_bits = combination.window_bits; - const auto expect_success = combination.expect_success; - auto result1 = Codec::Create(compression, codec_options); - auto result2 = Codec::Create(compression, codec_options); + auto result1 = Codec::Create(compression, codec_option); + auto result2 = Codec::Create(compression, codec_option); ASSERT_EQ(expect_success, result1.ok()); ASSERT_EQ(expect_success, result2.ok()); if (expect_success) { @@ -484,37 +468,169 @@ TEST(TestCodecMisc, SpecifyCodecOptionsGZip) { } } +TEST(TestCodecMisc, SpecifyCodecOptionsGZip) { + auto make_option = [](int compression_level, GZipFormat format, + std::optional window_bits) { + arrow::util::GZipCodecOptions option; + option.compression_level = compression_level; + option.gzip_format = format; + option.window_bits = window_bits; + return option; + }; + const std::pair options[]{ + {make_option(5, GZipFormat::GZIP, 15), true}, + {make_option(9, GZipFormat::ZLIB, 12), true}, + {make_option(-1, GZipFormat::DEFLATE, 10), true}, + {make_option(10, GZipFormat::GZIP, 25), false}, + {make_option(-992, GZipFormat::GZIP, 15), false}, + }; + CheckSpecifyCodecOptions(Compression::GZIP, options); +} + TEST(TestCodecMisc, SpecifyCodecOptionsBrotli) { - // for now only GZIP & Brotli codec options supported, since it has specific parameters - // to be customized, other codecs could directly go with CodecOptions, could add more - // specific codec options if needed. - struct CombinationOption { - int level; - int window_bits; - bool expect_success; + auto make_option = [](int compression_level, std::optional window_bits) { + arrow::util::BrotliCodecOptions option; + option.compression_level = compression_level; + option.window_bits = window_bits; + return option; }; - constexpr CombinationOption combinations[] = { - {8, 22, true}, {11, 10, true}, {1, 24, true}, {5, -12, false}, {-992, 25, false}}; + const std::pair options[]{ + {make_option(8, 22), true}, {make_option(11, 10), true}, + {make_option(1, 24), true}, {make_option(5, -12), false}, + {make_option(-992, 25), false}, + }; + CheckSpecifyCodecOptions(Compression::BROTLI, options); +} - std::vector data = MakeRandomData(2000); - for (const auto& combination : combinations) { - const auto compression = Compression::BROTLI; - if (!Codec::IsAvailable(compression)) { - // Support for this codec hasn't been built - continue; +TEST(TestCodecMisc, SpecifyCodecOptionsZstd) { + auto make_option = [](int compression_level, + std::vector> compression_context_params, + std::vector> decompression_context_params) { + arrow::util::ZstdCodecOptions option; + option.compression_level = compression_level; + option.compression_context_params = std::move(compression_context_params); + option.decompression_context_params = std::move(decompression_context_params); + return option; + }; + constexpr int ZSTD_c_windowLog = 101; + const std::pair options[]{ + {make_option(2, {}, {}), true}, + {make_option(9, {}, {}), true}, + {make_option(15, {}, {}), true}, + {make_option(-992, {}, {}), true}, + {make_option(3, {{ZSTD_c_windowLog, 23}}, {}), true}, + {make_option(3, {{ZSTD_c_windowLog, 28}}, {}), true}}; + CheckSpecifyCodecOptions(Compression::ZSTD, options); +} + +TEST(TestCodecMisc, ZstdLargerWindowLog) { + constexpr int ZSTD_c_windowLog = 101; + + arrow::util::ZstdCodecOptions option1; + arrow::util::ZstdCodecOptions option2; + option2.compression_context_params = {{ZSTD_c_windowLog, 28}}; + + std::vector data = MakeRandomData(4 * 1024 * 1024); + data.reserve(data.size() * 2); + data.insert(data.end(), data.begin(), data.end()); + + auto compress = [&data](const arrow::util::ZstdCodecOptions& codecOption) + -> Result> { + ARROW_ASSIGN_OR_RAISE(auto codec, Codec::Create(Compression::ZSTD, codecOption)); + auto max_compressed_len = codec->MaxCompressedLen(data.size(), data.data()); + std::vector compressed(max_compressed_len); + + ARROW_ASSIGN_OR_RAISE( + auto actual_size, + codec->Compress(data.size(), data.data(), max_compressed_len, compressed.data())); + compressed.resize(actual_size); + return compressed; + }; + + ASSERT_OK_AND_ASSIGN(auto compressed1, compress(option1)); + ASSERT_OK_AND_ASSIGN(auto compressed2, compress(option2)); + ASSERT_GT(compressed1.size(), compressed2.size()); +} + +TEST(TestCodecMisc, ZstdStreamLargerWindowLog) { + constexpr int ZSTD_c_windowLog = 101; + constexpr int ZSTD_d_windowLogMax = 100; + + arrow::util::ZstdCodecOptions option1; + arrow::util::ZstdCodecOptions option2; + option2.compression_context_params = {{ZSTD_c_windowLog, 28}}; + option2.decompression_context_params = {{ZSTD_d_windowLogMax, 28}}; + + std::vector data = MakeRandomData(4 * 1024 * 1024); + data.reserve(data.size() * 2); + data.insert(data.end(), data.begin(), data.end()); + + ASSERT_OK_AND_ASSIGN(auto codec1, Codec::Create(Compression::ZSTD, option1)); + ASSERT_OK_AND_ASSIGN(auto codec2, Codec::Create(Compression::ZSTD, option2)); + + auto compress = [&data](Codec& codec) -> Result> { + auto max_compressed_len = codec.MaxCompressedLen(data.size(), data.data()); + std::vector compressed(max_compressed_len); + + int64_t bytes_written = 0; + int64_t bytes_read = 0; + ARROW_ASSIGN_OR_RAISE(auto compressor, codec.MakeCompressor()); + while (bytes_read < static_cast(data.size())) { + ARROW_ASSIGN_OR_RAISE( + auto result, + compressor->Compress(data.size() - bytes_read, data.data() + bytes_read, + max_compressed_len - bytes_written, + compressed.data() + bytes_written)); + bytes_written += result.bytes_written; + bytes_read += result.bytes_read; } - auto codec_options = arrow::util::BrotliCodecOptions(); - codec_options.compression_level = combination.level; - codec_options.window_bits = combination.window_bits; - const auto expect_success = combination.expect_success; - auto result1 = Codec::Create(compression, codec_options); - auto result2 = Codec::Create(compression, codec_options); - ASSERT_EQ(expect_success, result1.ok()); - ASSERT_EQ(expect_success, result2.ok()); - if (expect_success) { - CheckCodecRoundtrip(*result1, *result2, data); + while (true) { + ARROW_ASSIGN_OR_RAISE(auto result, + compressor->End(max_compressed_len - bytes_written, + compressed.data() + bytes_written)); + bytes_written += result.bytes_written; + if (!result.should_retry) { + break; + } + } + compressed.resize(bytes_written); + return compressed; + }; + + ASSERT_OK_AND_ASSIGN(auto compressed1, compress(*codec1)); + ASSERT_OK_AND_ASSIGN(auto compressed2, compress(*codec2)); + ASSERT_GT(compressed1.size(), compressed2.size()); + + ASSERT_OK_AND_ASSIGN(auto decompressor1, codec1->MakeDecompressor()); + ASSERT_OK_AND_ASSIGN(auto decompressor2, codec2->MakeDecompressor()); + + std::vector decompressed(data.size()); + // Using a windowLog greater than ZSTD_WINDOWLOG_LIMIT_DEFAULT(1 << 27) at compression + // stage requires explicitly allowing such size at streaming decompression stage. + auto ret = decompressor1->Decompress(compressed2.size(), compressed2.data(), + decompressed.size(), decompressed.data()); + ASSERT_NOT_OK(ret); + ASSERT_EQ(ret.status().message(), + "ZSTD decompress failed: Frame requires too much memory for decoding"); + + int64_t bytes_written = 0; + int64_t bytes_read = 0; + while (true) { + ASSERT_OK_AND_ASSIGN(auto result, + decompressor2->Decompress(compressed2.size() - bytes_read, + compressed2.data() + bytes_read, + decompressed.size() - bytes_written, + decompressed.data() + bytes_written)); + bytes_read += result.bytes_read; + bytes_written += result.bytes_written; + if (!result.need_more_output) { + break; } } + ASSERT_TRUE(decompressor2->IsFinished()); + ASSERT_EQ(bytes_read, compressed2.size()); + ASSERT_EQ(bytes_written, data.size()); + ASSERT_EQ(decompressed, data); } TEST_P(CodecTest, MinMaxCompressionLevel) { diff --git a/cpp/src/arrow/util/compression_zstd.cc b/cpp/src/arrow/util/compression_zstd.cc index 8a8a6d46196..bc46c0ea4aa 100644 --- a/cpp/src/arrow/util/compression_zstd.cc +++ b/cpp/src/arrow/util/compression_zstd.cc @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include @@ -36,6 +38,9 @@ namespace internal { namespace { +using CCtxPtr = std::unique_ptr; +using DCtxPtr = std::unique_ptr; + Status ZSTDError(size_t ret, const char* prefix_msg) { return Status::IOError(prefix_msg, ZSTD_getErrorName(ret)); } @@ -45,19 +50,7 @@ Status ZSTDError(size_t ret, const char* prefix_msg) { class ZSTDDecompressor : public Decompressor { public: - ZSTDDecompressor() : stream_(ZSTD_createDStream()) {} - - ~ZSTDDecompressor() override { ZSTD_freeDStream(stream_); } - - Status Init() { - finished_ = false; - size_t ret = ZSTD_initDStream(stream_); - if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD init failed: "); - } else { - return Status::OK(); - } - } + explicit ZSTDDecompressor(DCtxPtr stream) : stream_(std::move(stream)) {} Result Decompress(int64_t input_len, const uint8_t* input, int64_t output_len, uint8_t* output) override { @@ -72,7 +65,7 @@ class ZSTDDecompressor : public Decompressor { out_buf.pos = 0; size_t ret; - ret = ZSTD_decompressStream(stream_, &out_buf, &in_buf); + ret = ZSTD_decompressStream(stream_.get(), &out_buf, &in_buf); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD decompress failed: "); } @@ -82,13 +75,20 @@ class ZSTDDecompressor : public Decompressor { in_buf.pos == 0 && out_buf.pos == 0}; } - Status Reset() override { return Init(); } + Status Reset() override { + finished_ = false; + auto ret = ZSTD_DCtx_reset(stream_.get(), ZSTD_reset_session_only); + if (ZSTD_isError(ret)) { + return ZSTDError(ret, "ZSTD reset failed: "); + } + return {}; + } bool IsFinished() override { return finished_; } - protected: - ZSTD_DStream* stream_; - bool finished_; + private: + DCtxPtr stream_; + bool finished_{false}; }; // ---------------------------------------------------------------------- @@ -96,19 +96,7 @@ class ZSTDDecompressor : public Decompressor { class ZSTDCompressor : public Compressor { public: - explicit ZSTDCompressor(int compression_level) - : stream_(ZSTD_createCStream()), compression_level_(compression_level) {} - - ~ZSTDCompressor() override { ZSTD_freeCStream(stream_); } - - Status Init() { - size_t ret = ZSTD_initCStream(stream_, compression_level_); - if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD init failed: "); - } else { - return Status::OK(); - } - } + explicit ZSTDCompressor(CCtxPtr stream) : stream_(std::move(stream)) {} Result Compress(int64_t input_len, const uint8_t* input, int64_t output_len, uint8_t* output) override { @@ -123,7 +111,7 @@ class ZSTDCompressor : public Compressor { out_buf.pos = 0; size_t ret; - ret = ZSTD_compressStream(stream_, &out_buf, &in_buf); + ret = ZSTD_compressStream(stream_.get(), &out_buf, &in_buf); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD compress failed: "); } @@ -139,7 +127,7 @@ class ZSTDCompressor : public Compressor { out_buf.pos = 0; size_t ret; - ret = ZSTD_flushStream(stream_, &out_buf); + ret = ZSTD_flushStream(stream_.get(), &out_buf); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD flush failed: "); } @@ -154,18 +142,15 @@ class ZSTDCompressor : public Compressor { out_buf.pos = 0; size_t ret; - ret = ZSTD_endStream(stream_, &out_buf); + ret = ZSTD_endStream(stream_.get(), &out_buf); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD end failed: "); } return EndResult{static_cast(out_buf.pos), ret > 0}; } - protected: - ZSTD_CStream* stream_; - private: - int compression_level_; + CCtxPtr stream_; }; // ---------------------------------------------------------------------- @@ -173,10 +158,14 @@ class ZSTDCompressor : public Compressor { class ZSTDCodec : public Codec { public: - explicit ZSTDCodec(int compression_level) + explicit ZSTDCodec(int compression_level, + std::vector> compression_context_params, + std::vector> decompression_context_params) : compression_level_(compression_level == kUseDefaultCompressionLevel ? kZSTDDefaultCompressionLevel - : compression_level) {} + : compression_level), + compression_context_params_(std::move(compression_context_params)), + decompression_context_params_(std::move(decompression_context_params)) {} Result Decompress(int64_t input_len, const uint8_t* input, int64_t output_buffer_len, uint8_t* output_buffer) override { @@ -188,8 +177,10 @@ class ZSTDCodec : public Codec { output_buffer = &empty_buffer; } - size_t ret = ZSTD_decompress(output_buffer, static_cast(output_buffer_len), - input, static_cast(input_len)); + ARROW_ASSIGN_OR_RAISE(auto dctx, CreateDCtx()); + size_t ret = ZSTD_decompressDCtx(dctx.get(), output_buffer, + static_cast(output_buffer_len), input, + static_cast(input_len)); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD decompression failed: "); } @@ -207,8 +198,10 @@ class ZSTDCodec : public Codec { Result Compress(int64_t input_len, const uint8_t* input, int64_t output_buffer_len, uint8_t* output_buffer) override { - size_t ret = ZSTD_compress(output_buffer, static_cast(output_buffer_len), - input, static_cast(input_len), compression_level_); + ARROW_ASSIGN_OR_RAISE(auto cctx, CreateCCtx()); + size_t ret = + ZSTD_compress2(cctx.get(), output_buffer, static_cast(output_buffer_len), + input, static_cast(input_len)); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD compression failed: "); } @@ -216,15 +209,13 @@ class ZSTDCodec : public Codec { } Result> MakeCompressor() override { - auto ptr = std::make_shared(compression_level_); - RETURN_NOT_OK(ptr->Init()); - return ptr; + ARROW_ASSIGN_OR_RAISE(auto cctx, CreateCCtx()); + return std::make_shared(std::move(cctx)); } Result> MakeDecompressor() override { - auto ptr = std::make_shared(); - RETURN_NOT_OK(ptr->Init()); - return ptr; + ARROW_ASSIGN_OR_RAISE(auto dctx, CreateDCtx()); + return std::make_shared(std::move(dctx)); } Compression::type compression_type() const override { return Compression::ZSTD; } @@ -235,13 +226,47 @@ class ZSTDCodec : public Codec { int compression_level() const override { return compression_level_; } private: + Result CreateCCtx() const { + CCtxPtr cctx{ZSTD_createCCtx(), ZSTD_freeCCtx}; + auto ret = + ZSTD_CCtx_setParameter(cctx.get(), ZSTD_c_compressionLevel, compression_level_); + if (ZSTD_isError(ret)) { + return ZSTDError(ret, "ZSTD_CCtx create failed: "); + } + for (auto& [key, value] : compression_context_params_) { + ret = ZSTD_CCtx_setParameter(cctx.get(), static_cast(key), value); + if (ZSTD_isError(ret)) { + return ZSTDError(ret, "ZSTD_CCtx create failed: "); + } + } + return cctx; + } + + Result CreateDCtx() const { + DCtxPtr dctx{ZSTD_createDCtx(), ZSTD_freeDCtx}; + for (auto& [key, value] : decompression_context_params_) { + auto ret = + ZSTD_DCtx_setParameter(dctx.get(), static_cast(key), value); + if (ZSTD_isError(ret)) { + return ZSTDError(ret, "ZSTD_DCtx create failed: "); + } + } + return dctx; + } + const int compression_level_; + const std::vector> compression_context_params_; + const std::vector> decompression_context_params_; }; } // namespace -std::unique_ptr MakeZSTDCodec(int compression_level) { - return std::make_unique(compression_level); +std::unique_ptr MakeZSTDCodec( + int compression_level, std::vector> compression_context_params, + std::vector> decompression_context_params) { + return std::make_unique(compression_level, + std::move(compression_context_params), + std::move(decompression_context_params)); } } // namespace internal diff --git a/cpp/src/parquet/column_writer_test.cc b/cpp/src/parquet/column_writer_test.cc index 20ae916ae0d..157e73ffec4 100644 --- a/cpp/src/parquet/column_writer_test.cc +++ b/cpp/src/parquet/column_writer_test.cc @@ -703,6 +703,13 @@ TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndZstdCompression) { this->TestRequiredWithSettings(Encoding::PLAIN, Compression::ZSTD, false, true, LARGE_SIZE); } +TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithZstdCodecOptions) { + constexpr int ZSTD_c_windowLog = 101; + auto codec_options = std::make_shared<::arrow::util::ZstdCodecOptions>(); + codec_options->compression_context_params = {{ZSTD_c_windowLog, 23}}; + this->TestRequiredWithCodecOptions(Encoding::PLAIN, Compression::ZSTD, false, false, + LARGE_SIZE, codec_options); +} #endif TYPED_TEST(TestPrimitiveWriter, Optional) { diff --git a/cpp/src/parquet/properties_test.cc b/cpp/src/parquet/properties_test.cc index 7be352aa5f1..a0df0d30487 100644 --- a/cpp/src/parquet/properties_test.cc +++ b/cpp/src/parquet/properties_test.cc @@ -81,9 +81,13 @@ TEST(TestWriterProperties, AdvancedHandling) { } TEST(TestWriterProperties, SetCodecOptions) { + constexpr int ZSTD_c_windowLog = 101; + WriterProperties::Builder builder; builder.compression("gzip", Compression::GZIP); - builder.compression("zstd", Compression::ZSTD); + auto zstd_codec_options = std::make_shared<::arrow::util::ZstdCodecOptions>(); + zstd_codec_options->compression_context_params = {{ZSTD_c_windowLog, 23}}; + builder.codec_options("zstd", zstd_codec_options); builder.compression("brotli", Compression::BROTLI); auto gzip_codec_options = std::make_shared<::arrow::util::GZipCodecOptions>(); gzip_codec_options->compression_level = 5;