diff --git a/DCVC-RT/README.md b/DCVC-RT/README.md
new file mode 100644
index 0000000..8a0018e
--- /dev/null
+++ b/DCVC-RT/README.md
@@ -0,0 +1,113 @@
+# Introduction
+
+Official Pytorch implementation for DCVC-RT: [Towards Practical **R**eal-**T**ime Neural Video Compression](https://arxiv.org/abs/2502.20762), in CVPR 2025.
+
+# Prerequisites
+* Python 3.12 and conda, get [Conda](https://www.anaconda.com/)
+* CUDA 12.6 (other versions may also work. Make sure the CUDA version matches with pytorch.)
+* pytorch (We have tested that pytorch-2.6 works. Other versions may also work.)
+* Environment
+ ```
+ conda create -n $YOUR_PY_ENV_NAME python=3.12
+ conda activate $YOUR_PY_ENV_NAME
+
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
+ pip install -r requirements.txt
+ ```
+
+# Test dataset
+
+We support arbitrary original resolution. The input video resolution will be padded automatically. The reconstructed video will be cropped back to the original size. The distortion (PSNR) is calculated at original resolution.
+
+## YUV 420 content
+
+Put the *.yuv files in a folder structure similar to the following.
+
+ /media/data/HEVC_B/
+ - BQTerrace_1920x1080_60.yuv
+ - BasketballDrive_1920x1080_50.yuv
+ - ...
+ /media/data/HEVC_D/
+ /media/data/HEVC_C/
+ ...
+
+The dataset structure can be seen in dataset_config_example_yuv420.json.
+
+## RGB content
+
+We highly suggest testing YUV420 content. To test RGB content, please refer to the [DCVC-FM](../DCVC-FM) folder.
+
+# Build the project
+Please build the C++ code to support bitstream writing and customized CUDA kernels to fuse operations.
+
+```bash
+sudo apt-get install cmake g++ ninja-build
+conda activate $YOUR_PY_ENV_NAME
+cd ./src/cpp/
+pip install .
+cd ../layers/extensions/inference/
+pip install .
+```
+
+# CPU performance scaling
+
+Note that the arithmetic coding runs on the CPU, please make sure your CPU runs at high performance while writing the actual bitstream. Otherwise, the arithmetic coding may take a long time.
+
+Check the CPU frequency by
+```
+grep -E '^model name|^cpu MHz' /proc/cpuinfo
+```
+
+Run the following command to set the CPU to its maximum frequency
+```
+echo performance | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+```
+
+Run the following command to recover the default frequency
+```
+echo ondemand | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+```
+
+# Pretrained models
+
+* Download [our pretrained models](https://1drv.ms/f/c/2866592d5c55df8c/Esu0KJ-I2kxCjEP565ARx_YB88i0UnR6XnODqFcvZs4LcA?e=by8CO8) and put them into ./checkpoints folder.
+* There are 2 models, one for image coding and the other for video coding.
+
+# Test the models
+
+Example to test pretrained model with four rate points:
+```bash
+ python test_video.py --model_path_i ./checkpoints/cvpr2025_image.pth.tar --model_path_p ./checkpoints/cvpr2025_video.pth.tar --rate_num 4 --test_config ./dataset_config_example_yuv420.json --cuda 1 -w 1 --write_stream 1 --force_zero_thres 0.12 --output_path output.json --force_intra_period -1 --reset_interval 64 --force_frame_num -1 --check_existing 0
+```
+
+It is recommended that the ```-w``` number is equal to your GPU number.
+
+You can also specify different ```--rate_num``` values (2~64) to test finer bitrate adjustment.
+
+# Comparing with other methods
+Bit saving over VTM-17.0 (UVG, all frames, single intra-frame setting (i.e., intra-period = -1), YUV420 colorspace).
+
+
+
+The BD-Rate and encoding/decoding speed on Nvidia A100 GPU
+
+
+
+# Acknowledgement
+The implementation is based on [CompressAI](https://github.com/InterDigitalInc/CompressAI).
+
+# Citation
+If you find this work useful for your research, please cite:
+
+```
+@inproceedings{jia2025towards,
+ title={Towards Practical Real-Time Neural Video Compression},
+ author={Jia, Zhaoyang and Li, Bin and Li, Jiahao and Xie, Wenxuan and Qi, Linfeng and Li, Houqiang and Lu, Yan},
+ booktitle={{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
+                  {CVPR} 2025, Nashville, TN, USA, June 11-15, 2025},
+ year={2025}
+}
+```
+
+# Trademarks
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
\ No newline at end of file
diff --git a/DCVC-RT/assets/RD-Curve.png b/DCVC-RT/assets/RD-Curve.png
new file mode 100644
index 0000000..1d93daf
Binary files /dev/null and b/DCVC-RT/assets/RD-Curve.png differ
diff --git a/DCVC-RT/assets/bd_rate_speed.png b/DCVC-RT/assets/bd_rate_speed.png
new file mode 100644
index 0000000..6da70d1
Binary files /dev/null and b/DCVC-RT/assets/bd_rate_speed.png differ
diff --git a/DCVC-RT/dataset_config_example_yuv420.json b/DCVC-RT/dataset_config_example_yuv420.json
new file mode 100644
index 0000000..3414f49
--- /dev/null
+++ b/DCVC-RT/dataset_config_example_yuv420.json
@@ -0,0 +1,100 @@
+{
+ "root_path": "/media/data/",
+ "test_classes": {
+ "UVG": {
+ "test": 1,
+ "base_path": "UVG",
+ "src_type": "yuv420",
+ "sequences": {
+ "Beauty_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "Bosphorus_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "HoneyBee_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "Jockey_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "ReadySteadyGo_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "ShakeNDry_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 300, "intra_period": -1},
+ "YachtRide_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1}
+ }
+ },
+ "MCL-JCV": {
+ "test": 1,
+ "base_path": "MCL-JCV",
+ "src_type": "yuv420",
+ "sequences": {
+ "videoSRC01_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC02_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC03_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC04_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC05_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC06_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC07_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC08_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC09_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC10_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC11_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC12_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC13_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC14_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC15_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC16_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC17_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC18_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC19_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC20_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC21_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC22_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC23_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC24_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC25_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC26_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC27_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC28_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC29_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC30_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}
+ }
+ },
+ "HEVC_B": {
+ "test": 1,
+ "base_path": "HEVC_B",
+ "src_type": "yuv420",
+ "sequences": {
+ "BQTerrace_1920x1080_60.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "BasketballDrive_1920x1080_50.yuv": {"width": 1920, "height": 1080, "frames": 500, "intra_period": -1},
+ "Cactus_1920x1080_50.yuv": {"width": 1920, "height": 1080, "frames": 500, "intra_period": -1},
+ "Kimono1_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 240, "intra_period": -1},
+ "ParkScene_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 240, "intra_period": -1}
+ }
+ },
+ "HEVC_E": {
+ "test": 1,
+ "base_path": "HEVC_E",
+ "src_type": "yuv420",
+ "sequences": {
+ "FourPeople_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1},
+ "Johnny_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1},
+ "KristenAndSara_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1}
+ }
+ },
+ "HEVC_C": {
+ "test": 1,
+ "base_path": "HEVC_C",
+ "src_type": "yuv420",
+ "sequences": {
+ "BQMall_832x480_60.yuv": {"width": 832, "height": 480, "frames": 600, "intra_period": -1},
+ "BasketballDrill_832x480_50.yuv": {"width": 832, "height": 480, "frames": 500, "intra_period": -1},
+ "PartyScene_832x480_50.yuv": {"width": 832, "height": 480, "frames": 500, "intra_period": -1},
+ "RaceHorses_832x480_30.yuv": {"width": 832, "height": 480, "frames": 300, "intra_period": -1}
+ }
+ },
+ "HEVC_D": {
+ "test": 1,
+ "base_path": "HEVC_D",
+ "src_type": "yuv420",
+ "sequences": {
+ "BasketballPass_416x240_50.yuv": {"width": 416, "height": 240, "frames": 500, "intra_period": -1},
+ "BlowingBubbles_416x240_50.yuv": {"width": 416, "height": 240, "frames": 500, "intra_period": -1},
+ "BQSquare_416x240_60.yuv": {"width": 416, "height": 240, "frames": 600, "intra_period": -1},
+ "RaceHorses_416x240_30.yuv": {"width": 416, "height": 240, "frames": 300, "intra_period": -1}
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/DCVC-RT/requirements.txt b/DCVC-RT/requirements.txt
new file mode 100644
index 0000000..21ad776
--- /dev/null
+++ b/DCVC-RT/requirements.txt
@@ -0,0 +1,7 @@
+numpy>=1.20.0
+scipy
+matplotlib
+tqdm
+bd-metric
+pillow
+pybind11
diff --git a/DCVC-RT/src/cpp/py_rans/py_rans.cpp b/DCVC-RT/src/cpp/py_rans/py_rans.cpp
new file mode 100644
index 0000000..1f5d047
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/py_rans.cpp
@@ -0,0 +1,393 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "py_rans.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace py = pybind11;
+
+RansEncoder::RansEncoder()
+{
+ m_encoder0 = std::make_shared();
+ m_encoder1 = std::make_shared();
+}
+
+void RansEncoder::encode_y(const py::array_t& symbols, const int cdf_group_index)
+{
+ py::buffer_info symbols_buf = symbols.request();
+ int16_t* symbols_ptr = static_cast(symbols_buf.ptr);
+
+ int symbolSize = static_cast(symbols.size());
+ if (m_use_two_encoders) {
+ int symbolSize0 = symbolSize / 2;
+ int symbolSize1 = symbolSize - symbolSize0;
+
+ auto vec_symbols0 = std::make_shared>(symbolSize0);
+ memcpy(vec_symbols0->data(), symbols_ptr, symbolSize0 * sizeof(int16_t));
+ m_encoder0->encode_y(vec_symbols0, cdf_group_index);
+ auto vec_symbols1 = std::make_shared>(symbolSize1);
+ memcpy(vec_symbols1->data(), symbols_ptr + symbolSize0, symbolSize1 * sizeof(int16_t));
+ m_encoder1->encode_y(vec_symbols1, cdf_group_index);
+ } else {
+ auto vec_symbols0 = std::make_shared>(symbolSize);
+ memcpy(vec_symbols0->data(), symbols_ptr, symbolSize * sizeof(int16_t));
+ m_encoder0->encode_y(vec_symbols0, cdf_group_index);
+ }
+}
+
+void RansEncoder::encode_z(const py::array_t& symbols, const int cdf_group_index,
+ const int start_offset, const int per_channel_size)
+{
+ py::buffer_info symbols_buf = symbols.request();
+ int8_t* symbols_ptr = static_cast(symbols_buf.ptr);
+
+ int symbolSize = static_cast(symbols.size());
+ if (m_use_two_encoders) {
+ int symbolSize0 = symbolSize / 2;
+ int symbolSize1 = symbolSize - symbolSize0;
+ int channel_half = symbolSize0 / per_channel_size;
+
+ auto vec_symbols0 = std::make_shared>(symbolSize0);
+ memcpy(vec_symbols0->data(), symbols_ptr, symbolSize0 * sizeof(int8_t));
+ m_encoder0->encode_z(vec_symbols0, cdf_group_index, start_offset, per_channel_size);
+ auto vec_symbols1 = std::make_shared>(symbolSize1);
+ memcpy(vec_symbols1->data(), symbols_ptr + symbolSize0, symbolSize1 * sizeof(int8_t));
+ m_encoder1->encode_z(vec_symbols1, cdf_group_index, start_offset + channel_half,
+ per_channel_size);
+ } else {
+ auto vec_symbols0 = std::make_shared>(symbolSize);
+ memcpy(vec_symbols0->data(), symbols_ptr, symbolSize * sizeof(int8_t));
+ m_encoder0->encode_z(vec_symbols0, cdf_group_index, start_offset, per_channel_size);
+ }
+}
+
+int RansEncoder::add_cdf(const py::array_t& cdfs, const py::array_t& cdfs_sizes,
+ const py::array_t& offsets)
+{
+ py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
+ py::buffer_info offsets_buf = offsets.request();
+ int32_t* cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr);
+ int32_t* offsets_ptr = static_cast(offsets_buf.ptr);
+
+ int cdf_num = static_cast(cdfs_sizes.size());
+ auto vec_cdfs_sizes = std::make_shared>(cdf_num);
+ memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
+ auto vec_offsets = std::make_shared>(offsets.size());
+ memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
+
+ int per_vector_size = static_cast(cdfs.size() / cdf_num);
+ auto vec_cdfs = std::make_shared>>(cdf_num);
+ auto cdfs_raw = cdfs.unchecked<2>();
+ for (int i = 0; i < cdf_num; i++) {
+ std::vector t(per_vector_size);
+ memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
+ vec_cdfs->at(i) = t;
+ }
+
+ int cdf_idx = m_encoder0->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+ m_encoder1->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+ return cdf_idx;
+}
+
+void RansEncoder::empty_cdf_buffer()
+{
+ m_encoder0->empty_cdf_buffer();
+ m_encoder1->empty_cdf_buffer();
+}
+
+void RansEncoder::flush()
+{
+ m_encoder0->flush();
+ m_encoder1->flush();
+}
+
+py::array_t RansEncoder::get_encoded_stream()
+{
+ if (m_use_two_encoders) {
+ auto result0 = m_encoder0->get_encoded_stream();
+ int nbytes0 = static_cast(result0->size());
+ auto result1 = m_encoder1->get_encoded_stream();
+ int nbytes1 = static_cast(result1->size());
+
+ int identical_bytes = 0;
+ int check_bytes = std::min(nbytes0, nbytes1);
+ check_bytes = std::min(check_bytes, 8);
+ for (int i = 0; i < check_bytes; i++) {
+ if (result0->at(nbytes0 - 1 - i) != 0) {
+ break;
+ }
+ if (result1->at(nbytes1 - 1 - i) != 0) {
+ break;
+ }
+ identical_bytes++;
+ }
+ if (identical_bytes == 0 && result0->at(nbytes0 - 1) == result1->at(nbytes1 - 1)) {
+ identical_bytes = 1;
+ }
+
+ py::array_t stream(nbytes0 + nbytes1 - identical_bytes);
+ py::buffer_info stream_buf = stream.request();
+ uint8_t* stream_ptr = static_cast(stream_buf.ptr);
+
+ std::copy(result0->begin(), result0->end(), stream_ptr);
+ std::reverse_copy(result1->begin(), result1->end() - identical_bytes, stream_ptr + nbytes0);
+ return stream;
+ }
+
+ auto result0 = m_encoder0->get_encoded_stream();
+ int nbytes0 = static_cast(result0->size());
+
+ py::array_t stream(nbytes0);
+ py::buffer_info stream_buf = stream.request();
+ uint8_t* stream_ptr = static_cast(stream_buf.ptr);
+
+ std::copy(result0->begin(), result0->end(), stream_ptr);
+ return stream;
+}
+
+void RansEncoder::reset()
+{
+ m_encoder0->reset();
+ m_encoder1->reset();
+}
+
+void RansEncoder::set_use_two_encoders(bool b)
+{
+ m_use_two_encoders = b;
+}
+
+bool RansEncoder::get_use_two_encoders()
+{
+ return m_use_two_encoders;
+}
+
+RansDecoder::RansDecoder()
+{
+ m_decoder0 = std::make_shared();
+ m_decoder1 = std::make_shared();
+}
+
+void RansDecoder::set_stream(const py::array_t& encoded)
+{
+ py::buffer_info encoded_buf = encoded.request();
+ const uint8_t* encoded_ptr = static_cast(encoded_buf.ptr);
+ const int encoded_size = static_cast(encoded.size());
+ auto stream0 = std::make_shared>(encoded.size());
+ std::copy(encoded_ptr, encoded_ptr + encoded_size, stream0->data());
+ m_decoder0->set_stream(stream0);
+ if (m_use_two_decoders) {
+ auto stream1 = std::make_shared>(encoded.size());
+ std::reverse_copy(encoded_ptr, encoded_ptr + encoded_size, stream1->data());
+ m_decoder1->set_stream(stream1);
+ }
+}
+
+void RansDecoder::decode_y(const py::array_t& indexes, const int cdf_group_index)
+{
+ py::buffer_info indexes_buf = indexes.request();
+ uint8_t* indexes_ptr = static_cast(indexes_buf.ptr);
+
+ int indexSize = static_cast(indexes.size());
+ if (m_use_two_decoders) {
+ int indexSize0 = indexSize / 2;
+ int indexSize1 = indexSize - indexSize0;
+
+ auto vec_indexes0 = std::make_shared>(indexSize0);
+ std::copy(indexes_ptr, indexes_ptr + indexSize0, vec_indexes0->data());
+ m_decoder0->decode_y(vec_indexes0, cdf_group_index);
+
+ auto vec_indexes1 = std::make_shared>(indexSize1);
+ std::copy(indexes_ptr + indexSize0, indexes_ptr + indexSize, vec_indexes1->data());
+ m_decoder1->decode_y(vec_indexes1, cdf_group_index);
+ } else {
+ auto vec_indexes0 = std::make_shared>(indexSize);
+ std::copy(indexes_ptr, indexes_ptr + indexSize, vec_indexes0->data());
+ m_decoder0->decode_y(vec_indexes0, cdf_group_index);
+ }
+}
+
+py::array_t RansDecoder::decode_and_get_y(const py::array_t& indexes,
+ const int cdf_group_index)
+{
+ decode_y(indexes, cdf_group_index);
+ return get_decoded_tensor();
+}
+
+void RansDecoder::decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+ const int per_channel_size)
+{
+ if (m_use_two_decoders) {
+ int symbolSize0 = total_size / 2;
+ int symbolSize1 = total_size - symbolSize0;
+ int channel_half = symbolSize0 / per_channel_size;
+ m_decoder0->decode_z(symbolSize0, cdf_group_index, start_offset, per_channel_size);
+ m_decoder1->decode_z(symbolSize1, cdf_group_index, start_offset + channel_half,
+ per_channel_size);
+ } else {
+ m_decoder0->decode_z(total_size, cdf_group_index, start_offset, per_channel_size);
+ }
+}
+
+py::array_t RansDecoder::get_decoded_tensor()
+{
+ if (m_use_two_decoders) {
+ auto result0 = m_decoder0->get_decoded_tensor();
+ const int total_size0 = static_cast(result0->size());
+
+ auto result1 = m_decoder1->get_decoded_tensor();
+ const int total_size1 = static_cast(result1->size());
+ py::array_t output(total_size0 + total_size1);
+ py::buffer_info buf = output.request();
+ int8_t* buf_ptr = static_cast(buf.ptr);
+ std::copy(result0->begin(), result0->end(), buf_ptr);
+ std::copy(result1->begin(), result1->end(), buf_ptr + total_size0);
+
+ return output;
+ }
+
+ auto result0 = m_decoder0->get_decoded_tensor();
+ const int total_size0 = static_cast(result0->size());
+
+ py::array_t output(total_size0);
+ py::buffer_info buf = output.request();
+ int8_t* buf_ptr = static_cast(buf.ptr);
+ std::copy(result0->begin(), result0->end(), buf_ptr);
+
+ return output;
+}
+
+int RansDecoder::add_cdf(const py::array_t& cdfs, const py::array_t& cdfs_sizes,
+ const py::array_t& offsets)
+{
+ py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
+ py::buffer_info offsets_buf = offsets.request();
+ int32_t* cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr);
+ int32_t* offsets_ptr = static_cast(offsets_buf.ptr);
+
+ int cdf_num = static_cast(cdfs_sizes.size());
+ auto vec_cdfs_sizes = std::make_shared>(cdf_num);
+ memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
+ auto vec_offsets = std::make_shared>(offsets.size());
+ memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
+
+ int per_vector_size = static_cast(cdfs.size() / cdf_num);
+ auto vec_cdfs = std::make_shared>>(cdf_num);
+ auto cdfs_raw = cdfs.unchecked<2>();
+ for (int i = 0; i < cdf_num; i++) {
+ std::vector t(per_vector_size);
+ memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
+ vec_cdfs->at(i) = t;
+ }
+ int cdf_idx = m_decoder0->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+ m_decoder1->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+ return cdf_idx;
+}
+
+void RansDecoder::empty_cdf_buffer()
+{
+ m_decoder0->empty_cdf_buffer();
+ m_decoder1->empty_cdf_buffer();
+}
+
+void RansDecoder::set_use_two_decoders(bool b)
+{
+ m_use_two_decoders = b;
+}
+
+bool RansDecoder::get_use_two_decoders()
+{
+ return m_use_two_decoders;
+}
+
+std::vector pmf_to_quantized_cdf(const std::vector& pmf, int precision)
+{
+ /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
+ * although it's only run once per model after training. See TF/compression
+ * implementation for an optimized version. */
+
+ std::vector cdf(pmf.size() + 1);
+ cdf[0] = 0; /* freq 0 */
+
+ std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) {
+ return static_cast(std::round(p * (1 << precision)) + 0.5);
+ });
+
+ const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0);
+
+ std::transform(cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) {
+ return static_cast((((1ull << precision) * p) / total));
+ });
+
+ std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
+ cdf.back() = 1 << precision;
+
+ for (int i = 0; i < static_cast(cdf.size() - 1); ++i) {
+ if (cdf[i] == cdf[i + 1]) {
+ /* Try to steal frequency from low-frequency symbols */
+ uint32_t best_freq = ~0u;
+ int best_steal = -1;
+ for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) {
+ uint32_t freq = cdf[j + 1] - cdf[j];
+ if (freq > 1 && freq < best_freq) {
+ best_freq = freq;
+ best_steal = j;
+ }
+ }
+
+ assert(best_steal != -1);
+
+ if (best_steal < i) {
+ for (int j = best_steal + 1; j <= i; ++j) {
+ cdf[j]--;
+ }
+ } else {
+ assert(best_steal > i);
+ for (int j = i + 1; j <= best_steal; ++j) {
+ cdf[j]++;
+ }
+ }
+ }
+ }
+
+ assert(cdf[0] == 0);
+ assert(cdf.back() == (1u << precision));
+ for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) {
+ assert(cdf[i + 1] > cdf[i]);
+ }
+
+ return cdf;
+}
+
+PYBIND11_MODULE(MLCodec_extensions_cpp, m)
+{
+ py::class_(m, "RansEncoder")
+ .def(py::init<>())
+ .def("encode_y", &RansEncoder::encode_y)
+ .def("encode_z", &RansEncoder::encode_z)
+ .def("flush", &RansEncoder::flush)
+ .def("get_encoded_stream", &RansEncoder::get_encoded_stream)
+ .def("reset", &RansEncoder::reset)
+ .def("add_cdf", &RansEncoder::add_cdf)
+ .def("empty_cdf_buffer", &RansEncoder::empty_cdf_buffer)
+ .def("set_use_two_encoders", &RansEncoder::set_use_two_encoders)
+ .def("get_use_two_encoders", &RansEncoder::get_use_two_encoders);
+
+ py::class_(m, "RansDecoder")
+ .def(py::init<>())
+ .def("set_stream", &RansDecoder::set_stream)
+ .def("decode_y", &RansDecoder::decode_y)
+ .def("decode_and_get_y", &RansDecoder::decode_and_get_y)
+ .def("decode_z", &RansDecoder::decode_z)
+ .def("get_decoded_tensor", &RansDecoder::get_decoded_tensor)
+ .def("add_cdf", &RansDecoder::add_cdf)
+ .def("empty_cdf_buffer", &RansDecoder::empty_cdf_buffer)
+ .def("set_use_two_decoders", &RansDecoder::set_use_two_decoders)
+ .def("get_use_two_decoders", &RansDecoder::get_use_two_decoders);
+
+ m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, "Return quantized CDF for a given PMF");
+}
diff --git a/DCVC-RT/src/cpp/py_rans/py_rans.h b/DCVC-RT/src/cpp/py_rans/py_rans.h
new file mode 100644
index 0000000..c7d223a
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/py_rans.h
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+#include "rans.h"
+#include
+
+#include
+#include
+#include
+
+namespace py = pybind11;
+
+// the classes in this file only perform the type conversion
+// from python type (numpy) to C++ type (vector)
+class RansEncoder {
+public:
+ RansEncoder();
+
+ RansEncoder(const RansEncoder&) = delete;
+ RansEncoder(RansEncoder&&) = delete;
+ RansEncoder& operator=(const RansEncoder&) = delete;
+ RansEncoder& operator=(RansEncoder&&) = delete;
+
+ void encode_y(const py::array_t& symbols, const int cdf_group_index);
+ void encode_z(const py::array_t& symbols, const int cdf_group_index,
+ const int start_offset, const int per_channel_size);
+ void flush();
+ py::array_t get_encoded_stream();
+ void reset();
+ int add_cdf(const py::array_t& cdfs, const py::array_t& cdfs_sizes,
+ const py::array_t& offsets);
+ void empty_cdf_buffer();
+ void set_use_two_encoders(bool b);
+ bool get_use_two_encoders();
+
+private:
+ std::shared_ptr m_encoder0;
+ std::shared_ptr m_encoder1;
+ bool m_use_two_encoders{ false };
+};
+
+class RansDecoder {
+public:
+ RansDecoder();
+
+ RansDecoder(const RansDecoder&) = delete;
+ RansDecoder(RansDecoder&&) = delete;
+ RansDecoder& operator=(const RansDecoder&) = delete;
+ RansDecoder& operator=(RansDecoder&&) = delete;
+
+ void set_stream(const py::array_t&);
+
+ void decode_y(const py::array_t& indexes, const int cdf_group_index);
+ py::array_t decode_and_get_y(const py::array_t& indexes, const int cdf_group_index);
+ void decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+ const int per_channel_size);
+ py::array_t get_decoded_tensor();
+ int add_cdf(const py::array_t& cdfs, const py::array_t& cdfs_sizes,
+ const py::array_t& offsets);
+ void empty_cdf_buffer();
+ void set_use_two_decoders(bool b);
+ bool get_use_two_decoders();
+
+private:
+ std::shared_ptr m_decoder0;
+ std::shared_ptr m_decoder1;
+ bool m_use_two_decoders{ false };
+};
+
+std::vector pmf_to_quantized_cdf(const std::vector& pmf, int precision);
diff --git a/DCVC-RT/src/cpp/py_rans/rans.cpp b/DCVC-RT/src/cpp/py_rans/rans.cpp
new file mode 100644
index 0000000..2c42bac
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/rans.cpp
@@ -0,0 +1,534 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Rans64 extensions from:
+ * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/
+ * Unbounded range coding from:
+ * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc
+ **/
+
+#include "rans.h"
+
+#include
+#include
+#include
+
+constexpr uint16_t bypass_precision = 2; /* number of bits in bypass mode */
+constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1;
+
+inline void RansEncPutBits(RansState& r, uint8_t*& ptr, uint32_t val)
+{
+ RansAssert(bypass_precision <= 8);
+ RansAssert(val < (1u << bypass_precision));
+
+ constexpr uint32_t freq = 1 << (SCALE_BITS - bypass_precision);
+ constexpr uint32_t x_max = freq << ENC_RENORM_SHIFT_BITS;
+ while (r >= x_max) {
+ *(--ptr) = static_cast(r & 0xff);
+ r >>= 8;
+ }
+
+ r = (r << bypass_precision) | val;
+}
+
+inline uint32_t RansDecGetBits(RansState& r, uint8_t*& ptr)
+{
+ uint32_t val = r & ((1u << bypass_precision) - 1);
+
+ /* Re-normalize */
+ r = r >> bypass_precision;
+ if (r < RANS_BYTE_L) {
+ r = (r << 8) | *ptr++;
+ RansAssert(r >= RANS_BYTE_L);
+ }
+
+ return val;
+}
+
+RansEncoderLib::RansEncoderLib()
+{
+ _stream = std::make_shared>();
+}
+
+int RansEncoderLib::add_cdf(const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets)
+{
+
+ auto ransSymbols = std::make_shared>>(cdfs->size());
+ for (int i = 0; i < static_cast(cdfs->size()); i++) {
+ const int32_t* cdf = cdfs->at(i).data();
+ std::vector ransSym(cdfs->at(i).size());
+ const int ransSize = static_cast(ransSym.size() - 1);
+ for (int j = 0; j < ransSize; j++) {
+ ransSym[j] = RansSymbol(
+ { static_cast(cdf[j]), static_cast(cdf[j + 1] - cdf[j]) });
+ }
+ ransSymbols->at(i) = ransSym;
+ }
+
+ _ransSymbols.push_back(ransSymbols);
+ _cdfs_sizes.push_back(cdfs_sizes);
+ _offsets.push_back(offsets);
+ return static_cast(_ransSymbols.size()) - 1;
+}
+
+void RansEncoderLib::empty_cdf_buffer()
+{
+ _ransSymbols.clear();
+ _cdfs_sizes.clear();
+ _offsets.clear();
+}
+
+FORCE_INLINE void RansEncoderLib::encode_one_symbol(uint8_t*& ptr, RansState& rans, const int32_t symbol,
+ const int32_t cdf_size, const int32_t offset,
+ const std::vector& ransSymbols)
+{
+ const int32_t max_value = cdf_size - 2;
+ int32_t value = symbol - offset;
+
+ uint32_t raw_val = 0;
+ if (value < 0) {
+ raw_val = -2 * value - 1;
+ value = max_value;
+ } else if (value >= max_value) {
+ raw_val = 2 * (value - max_value);
+ value = max_value;
+ }
+
+ if (value == max_value) {
+ std::vector bypassBins;
+ bypassBins.reserve(20);
+ /* Determine the number of bypasses (in bypass_precision size) needed to
+ * encode the raw value. */
+ int32_t n_bypass = 0;
+ while ((raw_val >> (n_bypass * bypass_precision)) != 0) {
+ ++n_bypass;
+ }
+
+ /* Encode number of bypasses */
+ int32_t val = n_bypass;
+ while (val >= max_bypass_val) {
+ bypassBins.push_back(max_bypass_val);
+ val -= max_bypass_val;
+ }
+ bypassBins.push_back(static_cast(val));
+
+ /* Encode raw value */
+ for (int32_t j = 0; j < n_bypass; ++j) {
+ const int32_t val1 = (raw_val >> (j * bypass_precision)) & max_bypass_val;
+ bypassBins.push_back(static_cast(val1));
+ }
+
+ for (auto it = bypassBins.rbegin(); it < bypassBins.rend(); it++) {
+ RansEncPutBits(rans, ptr, *it);
+ }
+ }
+ RansEncPut(rans, ptr, ransSymbols[value].start, ransSymbols[value].range);
+}
+
+void RansEncoderLib::encode_y(const std::shared_ptr> symbols,
+ const int cdf_group_index)
+{
+ PendingTask p;
+ p.workType = WorkType::EncodeDecodeY;
+ p.symbols_y = symbols;
+ p.cdf_group_index = cdf_group_index;
+ m_pendingEncodingList.push_back(p);
+}
+
+void RansEncoderLib::encode_z(const std::shared_ptr> symbols,
+ const int cdf_group_index, const int start_offset,
+ const int per_channel_size)
+{
+ PendingTask p;
+ p.workType = WorkType::EncodeDecodeZ;
+ p.symbols_z = symbols;
+ p.cdf_group_index = cdf_group_index;
+ p.start_offset = start_offset;
+ p.per_channel_size = per_channel_size;
+ m_pendingEncodingList.push_back(p);
+}
+#include
+FORCE_INLINE void RansEncoderLib::encode_y_internal(uint8_t*& ptr, RansState& rans,
+ const std::shared_ptr> symbols,
+ const int cdf_group_index)
+{
+ // backward loop on symbols from the end;
+ const int16_t* symbols_ptr = symbols->data();
+ const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+ const int32_t* offsets_ptr = _offsets[cdf_group_index]->data();
+ const int symbol_size = static_cast(symbols->size());
+
+ for (int i = symbol_size - 1; i >= 0; i--) {
+ const int32_t combined_symbol = symbols_ptr[i];
+ const int32_t cdf_idx = combined_symbol & 0xff;
+ const int32_t s = combined_symbol >> 8;
+ encode_one_symbol(ptr, rans, s, cdfs_sizes_ptr[cdf_idx], offsets_ptr[cdf_idx],
+ _ransSymbols[cdf_group_index]->at(cdf_idx));
+ }
+}
+
+FORCE_INLINE void RansEncoderLib::encode_z_internal(uint8_t*& ptr, RansState& rans,
+ const std::shared_ptr> symbols,
+ const int cdf_group_index, const int start_offset,
+ const int per_channel_size)
+{
+ // backward loop on symbols from the end;
+ const int8_t* symbols_ptr = symbols->data();
+ const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+ const int32_t* offsets_ptr = _offsets[cdf_group_index]->data();
+ const int symbol_size = static_cast(symbols->size());
+
+ for (int i = symbol_size - 1; i >= 0; i--) {
+ const int32_t cdf_idx = i / per_channel_size + start_offset;
+ encode_one_symbol(ptr, rans, symbols_ptr[i], cdfs_sizes_ptr[cdf_idx], offsets_ptr[cdf_idx],
+ _ransSymbols[cdf_group_index]->at(cdf_idx));
+ }
+}
+
+void RansEncoderLib::flush()
+{
+ RansState rans;
+ RansEncInit(rans);
+
+ int32_t total_symbol_size = 0;
+ for (auto it = m_pendingEncodingList.begin(); it != m_pendingEncodingList.end(); it++) {
+ if (it->workType == WorkType::EncodeDecodeY) {
+ total_symbol_size += static_cast(it->symbols_y->size());
+ } else if (it->workType == WorkType::EncodeDecodeZ) {
+ total_symbol_size += static_cast(it->symbols_z->size());
+ }
+ }
+
+ if (total_symbol_size == 0) {
+ _stream->resize(0);
+ return;
+ }
+
+ uint8_t* output = new uint8_t[total_symbol_size]; // too much space ?
+ uint8_t* ptrEnd = output + total_symbol_size;
+ uint8_t* ptr = ptrEnd;
+ assert(ptr != nullptr);
+
+ for (auto it = m_pendingEncodingList.rbegin(); it != m_pendingEncodingList.rend(); it++) {
+ PendingTask p = *it;
+ if (p.workType == WorkType::EncodeDecodeY) {
+ encode_y_internal(ptr, rans, p.symbols_y, p.cdf_group_index);
+ } else if (p.workType == WorkType::EncodeDecodeZ) {
+ encode_z_internal(ptr, rans, p.symbols_z, p.cdf_group_index, p.start_offset,
+ p.per_channel_size);
+ }
+ }
+
+ RansEncFlush(rans, ptr);
+
+ const int nbytes = static_cast(std::distance(ptr, ptrEnd));
+
+ _stream->resize(nbytes);
+ memcpy(_stream->data(), ptr, nbytes);
+ delete[] output;
+}
+
+std::shared_ptr> RansEncoderLib::get_encoded_stream()
+{
+ return _stream;
+}
+
+void RansEncoderLib::reset()
+{
+ m_pendingEncodingList.clear();
+ _stream->clear();
+}
+
+RansEncoderLibMultiThread::RansEncoderLibMultiThread()
+ : RansEncoderLib()
+ , m_finish(false)
+ , m_result_ready(false)
+ , m_thread(std::thread(&RansEncoderLibMultiThread::worker, this))
+{
+}
+RansEncoderLibMultiThread::~RansEncoderLibMultiThread()
+{
+ {
+ std::lock_guard lk(m_mutex_pending);
+ std::lock_guard lk1(m_mutex_result);
+ m_finish = true;
+ }
+ m_cv_pending.notify_one();
+ m_cv_result.notify_one();
+ m_thread.join();
+}
+
+void RansEncoderLibMultiThread::flush()
+{
+ PendingTask p;
+ p.workType = WorkType::Flush;
+ {
+ std::unique_lock lk(m_mutex_pending);
+ m_pending.push_back(p);
+ }
+ m_cv_pending.notify_one();
+}
+
+std::shared_ptr> RansEncoderLibMultiThread::get_encoded_stream()
+{
+ std::unique_lock lk(m_mutex_result);
+ m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; });
+ return RansEncoderLib::get_encoded_stream();
+}
+
+void RansEncoderLibMultiThread::reset()
+{
+ RansEncoderLib::reset();
+ std::lock_guard lk(m_mutex_result);
+ m_result_ready = false;
+}
+
+void RansEncoderLibMultiThread::worker()
+{
+ while (!m_finish) {
+ std::unique_lock lk(m_mutex_pending);
+ m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; });
+ if (m_finish) {
+ lk.unlock();
+ break;
+ }
+ if (m_pending.size() == 0) {
+ lk.unlock();
+ // std::cout << "continue in worker" << std::endl;
+ continue;
+ }
+ while (m_pending.size() > 0) {
+ auto p = m_pending.front();
+ m_pending.pop_front();
+ lk.unlock();
+ if (p.workType == WorkType::Flush) {
+ RansEncoderLib::flush();
+ {
+ std::lock_guard lk_result(m_mutex_result);
+ m_result_ready = true;
+ }
+ m_cv_result.notify_one();
+ }
+ lk.lock();
+ }
+ lk.unlock();
+ }
+}
+
+void RansDecoderLib::set_stream(const std::shared_ptr> encoded)
+{
+ _stream = encoded;
+ _ptr8 = (uint8_t*)(_stream->data());
+ RansDecInit(_rans, _ptr8);
+}
+
+int RansDecoderLib::add_cdf(const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets)
+{
+ _cdfs.push_back(cdfs);
+ _cdfs_sizes.push_back(cdfs_sizes);
+ _offsets.push_back(offsets);
+ return static_cast(_cdfs.size()) - 1;
+}
+
+void RansDecoderLib::empty_cdf_buffer()
+{
+ _cdfs.clear();
+ _cdfs_sizes.clear();
+ _offsets.clear();
+}
+
+FORCE_INLINE int8_t RansDecoderLib::decode_one_symbol(const int32_t* cdf, const int32_t cdf_size,
+ const int32_t offset)
+{
+ const int32_t max_value = cdf_size - 2;
+ const int32_t cum_freq = static_cast(RansDecGet(_rans));
+
+ int s = 1;
+ while (cdf[s++] <= cum_freq) {
+ }
+ s -= 2;
+
+ RansDecAdvance(_rans, _ptr8, cdf[s], cdf[s + 1] - cdf[s]);
+
+ int32_t value = static_cast(s);
+
+ if (value == max_value) {
+ /* Bypass decoding mode */
+ int32_t val = RansDecGetBits(_rans, _ptr8);
+ int32_t n_bypass = val;
+
+ while (val == max_bypass_val) {
+ val = RansDecGetBits(_rans, _ptr8);
+ n_bypass += val;
+ }
+
+ int32_t raw_val = 0;
+ for (int j = 0; j < n_bypass; ++j) {
+ val = RansDecGetBits(_rans, _ptr8);
+ raw_val |= val << (j * bypass_precision);
+ }
+ value = raw_val >> 1;
+ if (raw_val & 1) {
+ value = -value - 1;
+ } else {
+ value += max_value;
+ }
+ }
+
+ return static_cast(value + offset);
+}
+
+void RansDecoderLib::decode_y(const std::shared_ptr> indexes,
+ const int cdf_group_index)
+{
+ int index_size = static_cast(indexes->size());
+ m_decoded = std::make_shared>(index_size);
+
+ int8_t* outout_ptr = m_decoded->data();
+ const uint8_t* indexes_ptr = indexes->data();
+ const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+ const int32_t* offsets_ptr = _offsets[cdf_group_index]->data();
+ const auto& cdfs = _cdfs[cdf_group_index];
+ for (int i = 0; i < index_size; ++i) {
+ const int32_t cdf_idx = indexes_ptr[i];
+ outout_ptr[i] = decode_one_symbol(cdfs->at(cdf_idx).data(), cdfs_sizes_ptr[cdf_idx],
+ offsets_ptr[cdf_idx]);
+ }
+}
+
+void RansDecoderLib::decode_z(const int total_size, const int cdf_group_index,
+ const int start_offset, const int per_channel_size)
+{
+ m_decoded = std::make_shared>(total_size);
+
+ int8_t* outout_ptr = m_decoded->data();
+ const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+ const int32_t* offsets_ptr = _offsets[cdf_group_index]->data();
+ const auto& cdfs = _cdfs[cdf_group_index];
+ for (int i = 0; i < total_size; ++i) {
+ const int32_t cdf_idx = i / per_channel_size + start_offset;
+ outout_ptr[i] = decode_one_symbol(cdfs->at(cdf_idx).data(), cdfs_sizes_ptr[cdf_idx],
+ offsets_ptr[cdf_idx]);
+ }
+}
+
+std::shared_ptr> RansDecoderLib::get_decoded_tensor()
+{
+ return m_decoded;
+}
+
+RansDecoderLibMultiThread::RansDecoderLibMultiThread()
+ : RansDecoderLib()
+ , m_finish(false)
+ , m_result_ready(false)
+ , m_thread(std::thread(&RansDecoderLibMultiThread::worker, this))
+{
+}
+
+RansDecoderLibMultiThread::~RansDecoderLibMultiThread()
+{
+ {
+ std::lock_guard lk(m_mutex_pending);
+ std::lock_guard lk1(m_mutex_result);
+ m_finish = true;
+ }
+ m_cv_pending.notify_one();
+ m_cv_result.notify_one();
+ m_thread.join();
+}
+
+void RansDecoderLibMultiThread::decode_y(const std::shared_ptr> indexes,
+ const int cdf_group_index)
+{
+ {
+ std::lock_guard lk(m_mutex_result);
+ m_result_ready = false;
+ }
+ PendingTask p;
+ p.workType = WorkType::EncodeDecodeY;
+ p.indexes = indexes;
+ p.cdf_group_index = cdf_group_index;
+ {
+ std::unique_lock lk(m_mutex_pending);
+ m_pending.push_back(p);
+ }
+ m_cv_pending.notify_one();
+}
+
+void RansDecoderLibMultiThread::decode_z(const int total_size, const int cdf_group_index,
+ const int start_offset, const int per_channel_size)
+{
+ {
+ std::lock_guard lk(m_mutex_result);
+ m_result_ready = false;
+ }
+ PendingTask p;
+ p.workType = WorkType::EncodeDecodeZ;
+ p.total_size = total_size;
+ p.cdf_group_index = cdf_group_index;
+ p.start_offset = start_offset;
+ p.per_channel_size = per_channel_size;
+ {
+ std::unique_lock lk(m_mutex_pending);
+ m_pending.push_back(p);
+ }
+ m_cv_pending.notify_one();
+}
+
+std::shared_ptr> RansDecoderLibMultiThread::get_decoded_tensor()
+{
+ std::unique_lock lk(m_mutex_result);
+ m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; });
+ return RansDecoderLib::get_decoded_tensor();
+}
+
+void RansDecoderLibMultiThread::worker()
+{
+ while (!m_finish) {
+ std::unique_lock lk(m_mutex_pending);
+ m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; });
+ if (m_finish) {
+ lk.unlock();
+ break;
+ }
+ if (m_pending.size() == 0) {
+ lk.unlock();
+ // std::cout << "continue in worker" << std::endl;
+ continue;
+ }
+ while (m_pending.size() > 0) {
+ auto p = m_pending.front();
+ m_pending.pop_front();
+ lk.unlock();
+ if (p.workType == WorkType::EncodeDecodeY) {
+ RansDecoderLib::decode_y(p.indexes, p.cdf_group_index);
+ } else if (p.workType == WorkType::EncodeDecodeZ) {
+ RansDecoderLib::decode_z(p.total_size, p.cdf_group_index, p.start_offset,
+ p.per_channel_size);
+ }
+ {
+ std::lock_guard lk_result(m_mutex_result);
+ m_result_ready = true;
+ }
+ m_cv_result.notify_one();
+ lk.lock();
+ }
+ lk.unlock();
+ }
+}
diff --git a/DCVC-RT/src/cpp/py_rans/rans.h b/DCVC-RT/src/cpp/py_rans/rans.h
new file mode 100644
index 0000000..995dfee
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/rans.h
@@ -0,0 +1,201 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#ifdef __GNUC__
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wpedantic"
+ #pragma GCC diagnostic ignored "-Wsign-compare"
+#endif
+
+#ifdef _MSC_VER
+ #pragma warning(disable : 4244)
+#endif
+
+#include "rans_byte.h"
+
+#ifdef _MSC_VER
+ #pragma warning(default : 4244)
+#endif
+
+#ifdef __GNUC__
+ #pragma GCC diagnostic pop
+#endif
+
+#ifdef _MSC_VER
+ #define FORCE_INLINE __forceinline
+#endif
+
+#ifdef __GNUC__
+ #define FORCE_INLINE __attribute__((always_inline)) inline
+#endif
+
+struct RansSymbol {
+ uint16_t start;
+ uint16_t range; // range for normal coding and 0 for bypass coding
+};
+
+enum class WorkType {
+ EncodeDecodeY,
+ EncodeDecodeZ,
+ Flush,
+};
+
+struct PendingTask {
+ WorkType workType;
+ std::shared_ptr> symbols_y;
+ std::shared_ptr> symbols_z;
+ std::shared_ptr> indexes;
+ int total_size{ 0 };
+ int cdf_group_index{ 0 };
+ int start_offset{ 0 };
+ int per_channel_size{ 0 };
+};
+
+/* NOTE: Warning, we buffer everything for now... In case of large files we
+ * should split the bitstream into chunks... Or for a memory-bounded encoder
+ **/
+class RansEncoderLib {
+public:
+ RansEncoderLib();
+ virtual ~RansEncoderLib() = default;
+
+ RansEncoderLib(const RansEncoderLib&) = delete;
+ RansEncoderLib(RansEncoderLib&&) = delete;
+ RansEncoderLib& operator=(const RansEncoderLib&) = delete;
+ RansEncoderLib& operator=(RansEncoderLib&&) = delete;
+
+ void encode_y(const std::shared_ptr> symbols, const int cdf_group_index);
+ void encode_z(const std::shared_ptr> symbols, const int cdf_group_index,
+ const int start_offset, const int per_channel_size);
+
+ FORCE_INLINE void encode_y_internal(uint8_t*& ptr, RansState& rans,
+ const std::shared_ptr> symbols,
+ const int cdf_group_index);
+ FORCE_INLINE void encode_z_internal(uint8_t*& ptr, RansState& rans,
+ const std::shared_ptr> symbols,
+ const int cdf_group_index, const int start_offset,
+ const int per_channel_size);
+ FORCE_INLINE void encode_one_symbol(uint8_t*& ptr, RansState& rans, const int32_t symbol,
+ const int32_t cdf_size, const int32_t offset,
+ const std::vector& ransSymbols);
+ virtual void flush();
+ virtual std::shared_ptr> get_encoded_stream();
+ virtual void reset();
+ virtual int add_cdf(const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets);
+ virtual void empty_cdf_buffer();
+
+private:
+ std::shared_ptr> _stream;
+
+ std::vector>>> _ransSymbols;
+ std::vector>> _cdfs_sizes;
+ std::vector>> _offsets;
+
+ std::list m_pendingEncodingList;
+};
+
+class RansEncoderLibMultiThread : public RansEncoderLib {
+public:
+ RansEncoderLibMultiThread();
+ virtual ~RansEncoderLibMultiThread();
+ virtual void flush() override;
+ virtual std::shared_ptr> get_encoded_stream() override;
+ virtual void reset() override;
+
+ void worker();
+
+private:
+ bool m_finish;
+ bool m_result_ready;
+ std::thread m_thread;
+ std::mutex m_mutex_result;
+ std::mutex m_mutex_pending;
+ std::condition_variable m_cv_pending;
+ std::condition_variable m_cv_result;
+ std::list m_pending;
+};
+
+class RansDecoderLib {
+public:
+ RansDecoderLib() {}
+ virtual ~RansDecoderLib() = default;
+
+ RansDecoderLib(const RansDecoderLib&) = delete;
+ RansDecoderLib(RansDecoderLib&&) = delete;
+ RansDecoderLib& operator=(const RansDecoderLib&) = delete;
+ RansDecoderLib& operator=(RansDecoderLib&&) = delete;
+
+ virtual void set_stream(const std::shared_ptr> encoded);
+
+ FORCE_INLINE int8_t decode_one_symbol(const int32_t* cdf, const int32_t cdf_size,
+ const int32_t offset);
+
+ virtual void decode_y(const std::shared_ptr> indexes,
+ const int cdf_group_index);
+ virtual void decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+ const int per_channel_size);
+
+ virtual std::shared_ptr> get_decoded_tensor();
+
+ virtual int add_cdf(const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets);
+ virtual void empty_cdf_buffer();
+
+private:
+ RansState _rans;
+ uint8_t* _ptr8;
+ std::shared_ptr> _stream;
+ std::shared_ptr> m_decoded;
+
+ std::vector>>> _cdfs;
+ std::vector>> _cdfs_sizes;
+ std::vector>> _offsets;
+};
+
+class RansDecoderLibMultiThread : public RansDecoderLib {
+public:
+ RansDecoderLibMultiThread();
+ virtual ~RansDecoderLibMultiThread();
+
+ virtual void decode_y(const std::shared_ptr> indexes,
+ const int cdf_group_index) override;
+
+ virtual void decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+ const int per_channel_size) override;
+
+ virtual std::shared_ptr> get_decoded_tensor() override;
+
+ void worker();
+
+private:
+ bool m_finish;
+ bool m_result_ready;
+ std::thread m_thread;
+ std::mutex m_mutex_result;
+ std::mutex m_mutex_pending;
+ std::condition_variable m_cv_pending;
+ std::condition_variable m_cv_result;
+ std::list m_pending;
+};
\ No newline at end of file
diff --git a/DCVC-RT/src/cpp/py_rans/rans_byte.h b/DCVC-RT/src/cpp/py_rans/rans_byte.h
new file mode 100644
index 0000000..9c77e31
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/rans_byte.h
@@ -0,0 +1,141 @@
+// The code is from https://github.com/rygorous/ryg_rans
+// The original license is below.
+
+// To the extent possible under law, Fabian Giesen has waived all
+// copyright and related or neighboring rights to ryg_rans, as
+// per the terms of the CC0 license:
+
+// https://creativecommons.org/publicdomain/zero/1.0
+
+// This work is published from the United States.
+
+// Simple byte-aligned rANS encoder/decoder - public domain - Fabian 'ryg'
+// Giesen 2014
+//
+// Not intended to be "industrial strength"; just meant to illustrate the
+// general idea.
+
+#pragma once
+
+#include
+
+#ifdef assert
+ #define RansAssert assert
+#else
+ #define RansAssert(x)
+#endif
+
+// READ ME FIRST:
+//
+// This is designed like a typical arithmetic coder API, but there's three
+// twists you absolutely should be aware of before you start hacking:
+//
+// 1. You need to encode data in *reverse* - last symbol first. rANS works
+// like a stack: last in, first out.
+// 2. Likewise, the encoder outputs bytes *in reverse* - that is, you give
+// it a pointer to the *end* of your buffer (exclusive), and it will
+// slowly move towards the beginning as more bytes are emitted.
+// 3. Unlike basically any other entropy coder implementation you might
+// have used, you can interleave data from multiple independent rANS
+// encoders into the same bytestream without any extra signaling;
+// you can also just write some bytes by yourself in the middle if
+// you want to. This is in addition to the usual arithmetic encoder
+// property of being able to switch models on the fly. Writing raw
+// bytes can be useful when you have some data that you know is
+// incompressible, and is cheaper than going through the rANS encode
+// function. Using multiple rANS coders on the same byte stream wastes
+// a few bytes compared to using just one, but execution of two
+// independent encoders can happen in parallel on superscalar and
+// Out-of-Order CPUs, so this can be *much* faster in tight decoding
+// loops.
+//
+// This is why all the rANS functions take the write pointer as an
+// argument instead of just storing it in some context struct.
+
+// --------------------------------------------------------------------------
+
+// L ('l' in the paper) is the lower bound of our normalization interval.
+// Between this and our byte-aligned emission, we use 31 (not 32!) bits.
+// This is done intentionally because exact reciprocals for 31-bit uints
+// fit in 32-bit uints: this permits some optimizations during encoding.
+constexpr int SCALE_BITS = 16;
+constexpr int RANS_SHIFT_BITS = 23;
+constexpr uint32_t RANS_BYTE_L = (1u << RANS_SHIFT_BITS);
+constexpr int ENC_RENORM_SHIFT_BITS = RANS_SHIFT_BITS - SCALE_BITS + 8;
+constexpr int DEC_MASK = ((1u << SCALE_BITS) - 1);
+
+// State for a rANS encoder. Yep, that's all there is to it.
+typedef uint32_t RansState;
+
+// Initialize a rANS encoder.
+static inline void RansEncInit(RansState& r)
+{
+ r = RANS_BYTE_L;
+}
+
+// Renormalize the encoder. Internal function.
+static inline void RansEncRenorm(RansState& x, uint8_t*& ptr, uint32_t freq)
+{
+ const uint32_t x_max = freq << ENC_RENORM_SHIFT_BITS;
+ while (x >= x_max) {
+ *(--ptr) = static_cast(x & 0xff);
+ x >>= 8;
+ }
+}
+
+// Encodes a single symbol with range start "start" and frequency "freq".
+// All frequencies are assumed to sum to "1 << scale_bits", and the
+// resulting bytes get written to ptr (which is updated).
+//
+// NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
+// beginning to end! Likewise, the output bytestream is written *backwards*:
+// ptr starts pointing at the end of the output buffer and keeps decrementing.
+static inline void RansEncPut(RansState& r, uint8_t*& ptr, uint32_t start, uint32_t freq)
+{
+ // renormalize
+ RansEncRenorm(r, ptr, freq);
+
+ // x = C(s,x)
+ r = ((r / freq) << SCALE_BITS) + (r % freq) + start;
+}
+
+// Flushes the rANS encoder.
+static inline void RansEncFlush(const RansState& r, uint8_t*& ptr)
+{
+ ptr -= 4;
+ ptr[0] = (uint8_t)(r >> 0);
+ ptr[1] = (uint8_t)(r >> 8);
+ ptr[2] = (uint8_t)(r >> 16);
+ ptr[3] = (uint8_t)(r >> 24);
+}
+
+// Initializes a rANS decoder.
+// Unlike the encoder, the decoder works forwards as you'd expect.
+static inline void RansDecInit(RansState& r, uint8_t*& ptr)
+{
+ r = (*ptr++) << 0;
+ r |= (*ptr++) << 8;
+ r |= (*ptr++) << 16;
+ r |= (*ptr++) << 24;
+}
+
+// Returns the current cumulative frequency (map it to a symbol yourself!)
+static inline uint32_t RansDecGet(RansState& r)
+{
+ return r & DEC_MASK;
+}
+
+// Advances in the bit stream by "popping" a single symbol with range start
+// "start" and frequency "freq". All frequencies are assumed to sum to "1 <<
+// scale_bits", and the resulting bytes get written to ptr (which is updated).
+static inline void RansDecAdvance(RansState& r, uint8_t*& ptr, uint32_t start, uint32_t freq)
+{
+
+ // s, x = D(x)
+ r = freq * (r >> SCALE_BITS) + (r & DEC_MASK) - start;
+
+ // renormalize
+ while (r < RANS_BYTE_L) {
+ r = (r << 8) | *ptr++;
+ }
+}
diff --git a/DCVC-RT/src/cpp/setup.py b/DCVC-RT/src/cpp/setup.py
new file mode 100644
index 0000000..7e1b036
--- /dev/null
+++ b/DCVC-RT/src/cpp/setup.py
@@ -0,0 +1,31 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import glob
+import sys
+from setuptools import setup
+from pybind11.setup_helpers import Pybind11Extension, build_ext
+
+
+if sys.platform == "win32":
+ extra_compile_args = ['/std:c++17', '/O2', '/W4', '/WX', '/wd4100']
+ extra_link_args = []
+else:
+ extra_compile_args = ['-std=c++17', '-O3', '-fPIC', '-Wall', '-Wextra', '-Werror']
+ extra_link_args = []
+
+
+setup(
+ name="MLCodec_extensions_cpp",
+ ext_modules=[
+ Pybind11Extension(
+ name='MLCodec_extensions_cpp',
+ sources=glob.glob('py_rans/*.cpp'),
+ extra_compile_args=extra_compile_args,
+ extra_link_args=extra_link_args,
+ ),
+ ],
+ cmdclass={"build_ext": build_ext},
+ zip_safe=False,
+ python_requires=">=3.12",
+)
diff --git a/DCVC-RT/src/layers/cuda_inference.py b/DCVC-RT/src/layers/cuda_inference.py
new file mode 100644
index 0000000..dee9836
--- /dev/null
+++ b/DCVC-RT/src/layers/cuda_inference.py
@@ -0,0 +1,203 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+
+import torch
+import torch.nn.functional as F
+
+
+CUSTOMIZED_CUDA_INFERENCE = False
+try:
+ from inference_extensions_cuda import process_with_mask_cuda, combine_for_reading_2x_cuda, \
+ restore_y_2x_cuda, restore_y_4x_cuda, build_index_dec_cuda, \
+ round_and_to_int8_cuda, clamp_reciprocal_with_quant_cuda, bias_quant_cuda, \
+ add_and_multiply_cuda, bias_pixel_shuffle_8_cuda, replicate_pad_cuda, \
+ build_index_enc_cuda, DepthConvProxy, SubpelConv2xProxy # noqa: F401
+ CUSTOMIZED_CUDA_INFERENCE = True
+except Exception: # pylint: disable=W0718
+ pass
+
+
+if not CUSTOMIZED_CUDA_INFERENCE and 'SUPPRESS_CUSTOM_KERNEL_WARNING' not in os.environ:
+ print("cannot import cuda implementation for inference, fallback to pytorch.")
+
+
+def round_and_to_int8(z):
+ if CUSTOMIZED_CUDA_INFERENCE and z.is_cuda:
+ z_int8 = round_and_to_int8_cuda(z)
+ return z, z_int8
+
+ z_hat = torch.clamp(torch.round(z), -128., 127.)
+ z_hat_write = z_hat.to(dtype=torch.int8)
+ return z_hat, z_hat_write
+
+
+def clamp_reciprocal_with_quant(q_dec, y, min_val):
+ if CUSTOMIZED_CUDA_INFERENCE and q_dec.is_cuda:
+ # q_dec is not inplace modified at decoder side
+ q_dec = clamp_reciprocal_with_quant_cuda(q_dec, y, min_val)
+ return q_dec, y
+
+ q_dec = torch.clamp_min(q_dec, min_val)
+ q_enc = torch.reciprocal(q_dec)
+ y = y * q_enc
+ return q_dec, y
+
+
+def add_and_multiply(y_hat_0, y_hat_1, q_dec):
+ if CUSTOMIZED_CUDA_INFERENCE and y_hat_0.is_cuda:
+ add_and_multiply_cuda(y_hat_0, y_hat_1, q_dec)
+ return y_hat_0
+
+ y_hat = y_hat_0 + y_hat_1
+ y_hat = y_hat * q_dec
+ return y_hat
+
+
+def process_with_mask(y, scales, means, mask, force_zero_thres):
+ if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda:
+ thres = force_zero_thres if force_zero_thres is not None else -1.
+ return process_with_mask_cuda(y, scales, means, mask, thres)
+
+ scales_hat = scales * mask
+ means_hat = means * mask
+
+ y_res = (y - means_hat) * mask
+ y_q = torch.round(y_res)
+ if force_zero_thres is not None:
+ cond = scales_hat > force_zero_thres
+ y_q = y_q * cond
+ y_q = torch.clamp(y_q, -128., 127.)
+ y_hat = y_q + means_hat
+
+ return y_res, y_q, y_hat, scales_hat
+
+
+def combine_for_reading_2x(x, mask, inplace=False):
+ if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda and x.is_contiguous():
+ B, C, H, W = x.shape
+ if inplace:
+ out = x[:, :C // 2, :, :]
+ else:
+ out = torch.empty((B, C // 2, H, W), dtype=x.dtype, layout=x.layout, device=x.device)
+ combine_for_reading_2x_cuda(out, x, mask)
+ return out
+
+ x = x * mask
+ x0, x1 = x.chunk(2, 1)
+ return x0 + x1
+
+
+def restore_y_2x(y, means, mask):
+ if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda and y.is_contiguous():
+ out = torch.empty_like(means)
+ restore_y_2x_cuda(out, y, means, mask)
+ return out
+
+ return (torch.cat((y, y), dim=1) + means) * mask
+
+
+def restore_y_2x_with_cat_after(y, means, mask, to_cat):
+ if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda and y.is_contiguous():
+ B, C1, H, W = means.shape
+ C2 = to_cat.shape[1]
+ out = torch.empty((B, C1 + C2, H, W), dtype=means.dtype, layout=means.layout,
+ device=means.device)
+ restore_y_2x_cuda(out[:, :C1, :, :], y, means, mask)
+ out[:, C1:, :, :] = to_cat
+ return out[:, :C1, :, :], out
+
+ out = (torch.cat((y, y), dim=1) + means) * mask
+ return out, torch.cat((out, to_cat), dim=1)
+
+
+def restore_y_4x(y, means, mask):
+ if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda and y.is_contiguous():
+ out = torch.empty_like(means)
+ restore_y_4x_cuda(out, y, means, mask)
+ return out
+
+ return (torch.cat((y, y, y, y), dim=1) + means) * mask
+
+
+def build_index_dec(scales, scale_min, scale_max, log_scale_min, log_step_recip, skip_thres=None):
+ if CUSTOMIZED_CUDA_INFERENCE and scales.is_cuda:
+ out = torch.empty_like(scales, dtype=torch.uint8)
+ skip_cond = None
+ if skip_thres is not None:
+ skip_cond = torch.empty_like(scales, dtype=torch.bool)
+ else:
+ skip_thres = -1.
+
+ build_index_dec_cuda(out, skip_cond, scales, scale_min, scale_max, log_scale_min,
+ log_step_recip, skip_thres)
+ return out, skip_cond
+
+ skip_cond = None
+ if skip_thres is not None:
+ skip_cond = scales > skip_thres
+ scales = scales.clamp_(scale_min, scale_max)
+ indexes = (torch.log(scales) - log_scale_min) * log_step_recip
+ indexes = indexes.to(dtype=torch.uint8)
+ return indexes, skip_cond
+
+
+def build_index_enc(symbols, scales, scale_min, scale_max, log_scale_min,
+ log_step_recip, skip_thres=None):
+ if CUSTOMIZED_CUDA_INFERENCE and scales.is_cuda:
+ out = torch.empty_like(scales, dtype=torch.int16)
+ skip_cond = None
+ if skip_thres is not None:
+ skip_cond = torch.empty_like(scales, dtype=torch.bool)
+ else:
+ skip_thres = -1.
+
+ build_index_enc_cuda(out, skip_cond, symbols, scales, scale_min, scale_max, log_scale_min,
+ log_step_recip, skip_thres)
+
+ out = out[skip_cond]
+ return out
+
+ scales = scales.clamp_(scale_min, scale_max)
+ indexes = (torch.log(scales) - log_scale_min) * log_step_recip
+ indexes = indexes.to(dtype=torch.uint8)
+ symbols = symbols.to(dtype=torch.int16)
+ out = (symbols << 8) + indexes
+ out = out.to(dtype=torch.int16)
+ if skip_thres is not None:
+ skip_cond = scales > skip_thres
+ out = out[skip_cond]
+ return out
+
+
+def replicate_pad(x, pad_b, pad_r):
+ if pad_b == 0 and pad_r == 0:
+ return x
+ if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda:
+ return replicate_pad_cuda(x, pad_b, pad_r)
+ return F.pad(x, (0, pad_r, 0, pad_b), mode="replicate")
+
+
+def bias_pixel_shuffle_8(x, bias):
+ if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda:
+ B, C, H, W = x.shape
+ assert B == 1
+ out = torch.empty((B, 3, H * 8, W * 8), dtype=x.dtype, device=x.device, layout=x.layout)
+ bias_pixel_shuffle_8_cuda(out, x, bias, C, H * W, W, True)
+ return out
+
+ out = x + bias[None, :, None, None]
+ out = F.pixel_shuffle(out, 8)
+ out = torch.clamp(out, 0., 1.)
+ return out
+
+
+def bias_quant(x, bias, quant_step):
+ if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda:
+ bias_quant_cuda(x, bias, quant_step)
+ return x
+
+ out = x + bias[None, :, None, None]
+ out = out * quant_step
+ return out
diff --git a/DCVC-RT/src/layers/extensions/inference/bind.cpp b/DCVC-RT/src/layers/extensions/inference/bind.cpp
new file mode 100644
index 0000000..0bc8e24
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/bind.cpp
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "def.h"
+#include
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("process_with_mask_cuda", &process_with_mask_cuda);
+ m.def("combine_for_reading_2x_cuda", &combine_for_reading_2x_cuda);
+ m.def("restore_y_2x_cuda", &restore_y_2x_cuda);
+ m.def("restore_y_4x_cuda", &restore_y_4x_cuda);
+ m.def("build_index_dec_cuda", &build_index_dec_cuda);
+ m.def("build_index_enc_cuda", &build_index_enc_cuda);
+ m.def("bias_quant_cuda", &bias_quant_cuda);
+ m.def("round_and_to_int8_cuda", &round_and_to_int8_cuda);
+ m.def("clamp_reciprocal_with_quant_cuda", &clamp_reciprocal_with_quant_cuda);
+ m.def("add_and_multiply_cuda", &add_and_multiply_cuda);
+ m.def("bias_pixel_shuffle_8_cuda", &bias_pixel_shuffle_8_cuda);
+ m.def("replicate_pad_cuda", &replicate_pad_cuda);
+ m.def("bias_wsilu_depthwise_conv2d_cuda", &bias_wsilu_depthwise_conv2d_cuda);
+
+ py::class_(m, "DepthConvProxy")
+ .def(py::init<>())
+ .def("set_param", &DepthConvProxy::set_param)
+ .def("set_param_with_adaptor", &DepthConvProxy::set_param_with_adaptor)
+ .def("forward", &DepthConvProxy::forward)
+ .def("forward_with_quant_step", &DepthConvProxy::forward_with_quant_step)
+ .def("forward_with_cat", &DepthConvProxy::forward_with_cat);
+
+ py::class_(m, "SubpelConv2xProxy")
+ .def(py::init<>())
+ .def("set_param", &SubpelConv2xProxy::set_param)
+ .def("forward", &SubpelConv2xProxy::forward)
+ .def("forward_with_cat", &SubpelConv2xProxy::forward_with_cat);
+}
diff --git a/DCVC-RT/src/layers/extensions/inference/common.h b/DCVC-RT/src/layers/extensions/inference/common.h
new file mode 100644
index 0000000..7a8bfed
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/common.h
@@ -0,0 +1,319 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+#include
+#include
+#include
+#include
+#include
+
+// T maybe vector type, and may be different from t.dtype
+template
+struct GPUTensor1D {
+ GPUTensor1D(torch::Tensor& t) : ptr(static_cast(t.data_ptr())) {}
+ GPUTensor1D(const torch::Tensor& t) : ptr(static_cast(t.data_ptr())) {}
+ GPUTensor1D(T* t) : ptr(static_cast(t)) { assert(t == nullptr); }
+
+ __device__ T& operator[](int idx) { return ptr[idx]; }
+ __device__ T& operator[](int idx) const { return ptr[idx]; }
+
+ T* __restrict__ const ptr;
+};
+
+template
+using Packed4DTensorAccessor32 = torch::PackedTensorAccessor32;
+template
+using Packed1DTensorAccessor32 = torch::PackedTensorAccessor32;
+
+struct __align__(8) Half4
+{
+ c10::Half x;
+ c10::Half y;
+ c10::Half z;
+ c10::Half w;
+};
+
+struct __align__(4) bool4
+{
+ bool x;
+ bool y;
+ bool z;
+ bool w;
+};
+
// make_vec4 overloads construct a 4-lane vector from scalar lanes.  The
// shared name lets templated vector code build float4 / Half4 / bool4
// uniformly via make_vec4(...).
__forceinline__ __device__ float4 make_vec4(const float& x, const float& y, const float& z,
                                            const float& w)
{
    return make_float4(x, y, z, w);
}

__forceinline__ __device__ Half4 make_vec4(const c10::Half& x, const c10::Half& y,
                                           const c10::Half& z, const c10::Half& w)
{
    Half4 t;
    t.x = x;
    t.y = y;
    t.z = z;
    t.w = w;
    return t;
}

// Explicitly named Half4 builder (same body as the make_vec4 overload above);
// used where overload resolution on make_vec4 would be ambiguous.
__forceinline__ __device__ Half4 make_Half4(const c10::Half& x, const c10::Half& y,
                                            const c10::Half& z, const c10::Half& w)
{
    Half4 t;
    t.x = x;
    t.y = y;
    t.z = z;
    t.w = w;
    return t;
}

__forceinline__ __device__ bool4 make_vec4(const bool& x, const bool& y, const bool& z, const bool& w)
{
    bool4 t;
    t.x = x;
    t.y = y;
    t.z = z;
    t.w = w;
    return t;
}
+
+__forceinline__ __device__ c10::Half round(const c10::Half& a)
+{
+ return static_cast(__half2int_rn(a));
+}
+
+template
+__forceinline__ __device__ T round(const T& a)
+{
+ return make_vec4(round(a.x), round(a.y), round(a.z), round(a.w));
+}
+
+__forceinline__ __device__ int8_t to_int8(const float& a)
+{
+ return static_cast(a);
+}
+
+__forceinline__ __device__ int8_t to_int8(const c10::Half& a)
+{
+ return static_cast(a);
+}
+
+template
+__forceinline__ __device__ char4 to_int8(const T& a)
+{
+ return make_char4(to_int8(a.x), to_int8(a.y), to_int8(a.z), to_int8(a.w));
+}
+
+__forceinline__ __device__ uint8_t to_uint8(const float& a)
+{
+ return static_cast(a);
+}
+
+__forceinline__ __device__ uint8_t to_uint8(const c10::Half& a)
+{
+ return static_cast(__half2uint_rd(a));
+}
+
+template
+__forceinline__ __device__ uchar4 to_uint8(const T& a)
+{
+ return make_uchar4(to_uint8(a.x), to_uint8(a.y), to_uint8(a.z), to_uint8(a.w));
+}
+
+__forceinline__ __device__ int16_t to_int16(const float& a)
+{
+ return static_cast(a);
+}
+
+__forceinline__ __device__ int16_t to_int16(const c10::Half& a)
+{
+ return static_cast(__half2int_rd(a));
+}
+
+template
+__forceinline__ __device__ short4 to_int16(const T& a)
+{
+ return make_short4(to_int16(a.x), to_int16(a.y), to_int16(a.z), to_int16(a.w));
+}
+
// Lane-wise left shift for packed int16 codes.
__forceinline__ __device__ short4 operator<<(const short4& a, const int b)
{
    return make_short4(a.x << b, a.y << b, a.z << b, a.w << b);
}

// Scalar half min/max through the hardware intrinsics.
__forceinline__ __device__ c10::Half min(const c10::Half& a, const c10::Half& b)
{
    return __hmin(a, b);
}
__forceinline__ __device__ c10::Half max(const c10::Half& a, const c10::Half& b)
{
    return __hmax(a, b);
}

// Half comparisons via intrinsics (no implicit promotion to float).
__forceinline__ __device__ bool operator>(const c10::Half& a, const c10::Half& b)
{
    return __hgt(a, b);
}

__forceinline__ __device__ bool operator<(const c10::Half& a, const c10::Half& b)
{
    return __hlt(a, b);
}

// Natural logarithm for half.
__forceinline__ __device__ c10::Half log(const c10::Half& a)
{
    return hlog(a);
}
+
+__forceinline__ __device__ short4 operator+(const short4& a, const short4& b)
+{
+ return make_short4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+
+__forceinline__ __device__ float4 operator+(const float4& a, const float4& b)
+{
+ return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+
+__forceinline__ __device__ Half4 operator+(const Half4& a, const Half4& b)
+{
+ return make_Half4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+
+template
+__forceinline__ __device__ T operator-(const T& a, const T& b)
+{
+ return make_vec4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
+}
+
+__forceinline__ __device__ float4 operator-(const float4& a, const float& b)
+{
+ return make_vec4(a.x - b, a.y - b, a.z - b, a.w - b);
+}
+
+__forceinline__ __device__ Half4 operator-(const Half4& a, const c10::Half& b)
+{
+ return make_vec4(a.x - b, a.y - b, a.z - b, a.w - b);
+}
+
+__forceinline__ __device__ c10::Half operator*(const c10::Half& a, const bool b)
+{
+ return b ? a : static_cast(0.f);
+}
+
+template
+__forceinline__ __device__ T operator*(const T& a, const T& b)
+{
+ return make_vec4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+
+template
+__forceinline__ __device__ T operator*(const T& a, const bool4& b)
+{
+ return make_vec4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+
+template
+__forceinline__ __device__ T1 operator*(const T1& a, const T2& b)
+{
+ return make_vec4(a.x * b, a.y * b, a.z * b, a.w * b);
+}
+
+template
+__forceinline__ __device__ T max(const T& a, const T& b)
+{
+ return make_vec4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w));
+}
+
+template
+__forceinline__ __device__ T1 max(const T1& a, const T2& b)
+{
+ return make_vec4(max(a.x, b), max(a.y, b), max(a.z, b), max(a.w, b));
+}
+
+template
+__forceinline__ __device__ T min(const T& a, const T& b)
+{
+ return make_vec4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w));
+}
+
+template
+__forceinline__ __device__ T1 min(const T1& a, const T2& b)
+{
+ return make_vec4(min(a.x, b), min(a.y, b), min(a.z, b), min(a.w, b));
+}
+
+template
+__forceinline__ __device__ T log(const T& a)
+{
+ return make_vec4(log(a.x), log(a.y), log(a.z), log(a.w));
+}
+
+__forceinline__ __device__ float reciprocal(const float& a)
+{
+ return __frcp_rd(a);
+}
+
+__forceinline__ __device__ c10::Half reciprocal(const c10::Half& a)
+{
+ return hrcp(a);
+}
+
+template
+__forceinline__ __device__ T reciprocal(const T& a)
+{
+ return make_vec4(reciprocal(a.x), reciprocal(a.y), reciprocal(a.z), reciprocal(a.w));
+}
+
+template
+__forceinline__ __device__ bool4 operator>(const T1& a, const T2& b)
+{
+ return make_vec4(a.x > b, a.y > b, a.z > b, a.w > b);
+}
+
__forceinline__ __device__ float sigmoid(const float x)
{
    return 1.0f / (1.0f + expf(-x));
}

// "wsilu": SiLU variant with the sigmoid input scaled by 4, i.e.
// x * sigmoid(4x).
__forceinline__ __device__ float wsilu(const float x)
{
    return x * sigmoid(4.0f * x);
}

// Half variant computes in float for accuracy, then rounds back to half.
__forceinline__ __device__ c10::Half wsilu(const c10::Half x)
{
    return __float2half_rn(wsilu(__half2float(x)));
}

// Lane-wise wsilu for the 4-lane vector types.
__forceinline__ __device__ float4 wsilu(float4 data)
{
    data.x = wsilu(data.x);
    data.y = wsilu(data.y);
    data.z = wsilu(data.z);
    data.w = wsilu(data.w);
    return data;
}

__forceinline__ __device__ Half4 wsilu(Half4 data)
{
    data.x = wsilu(data.x);
    data.y = wsilu(data.y);
    data.z = wsilu(data.z);
    data.w = wsilu(data.w);
    return data;
}
+
// Fused multiply-add: a * b + c with a single rounding step.
__forceinline__ __device__ float multiply_add(const float a, const float b, const float c)
{
    return __fmaf_rn(a, b, c);
}

__forceinline__ __device__ c10::Half multiply_add(const c10::Half a, const c10::Half b, const c10::Half c)
{
    return __hfma(a, b, c);
}
\ No newline at end of file
diff --git a/DCVC-RT/src/layers/extensions/inference/def.h b/DCVC-RT/src/layers/extensions/inference/def.h
new file mode 100644
index 0000000..17e7782
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/def.h
@@ -0,0 +1,107 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include
+
+std::tuple
+process_with_mask_cuda(const torch::Tensor& y, const torch::Tensor& scales, const torch::Tensor& means,
+ const torch::Tensor& mask, const float force_zero_thres);
+
+void combine_for_reading_2x_cuda(torch::Tensor& out, const torch::Tensor& x, const torch::Tensor& mask);
+void restore_y_2x_cuda(torch::Tensor& out, const torch::Tensor& y, const torch::Tensor& means,
+ const torch::Tensor& mask);
+void restore_y_4x_cuda(torch::Tensor& out, const torch::Tensor& y, const torch::Tensor& means,
+ const torch::Tensor& mask);
+
+void build_index_dec_cuda(torch::Tensor& out, torch::optional& cond_out,
+ const torch::Tensor& scales, const float scale_min, const float scale_max,
+ const float log_scale_min, const float log_step_recip,
+ const float skip_thres);
+
+void build_index_enc_cuda(torch::Tensor& out, torch::optional& cond_out,
+ const torch::Tensor& symbols, const torch::Tensor& scales,
+ const float scale_min, const float scale_max, const float log_scale_min,
+ const float log_step_recip, const float skip_thres);
+
+void bias_wsilu_cuda(torch::Tensor& x, const torch::Tensor& bias);
+
+void bias_shortcut_cuda(torch::Tensor& x, const torch::Tensor& bias, const torch::Tensor& shortcut);
+void bias_shortcut_no_inplace_cuda(torch::Tensor& out, const torch::Tensor& x,
+ const torch::Tensor& bias, const torch::Tensor& shortcut);
+void bias_shortcut_2_cuda(torch::Tensor& x, const torch::Tensor& bias, torch::Tensor& shortcut);
+void bias_shortcut_with_quant_step_cuda(torch::Tensor& x, const torch::Tensor& bias,
+ const torch::Tensor& quant_step, const torch::Tensor& shortcut);
+
+void bias_quant_cuda(torch::Tensor& x, const torch::Tensor& bias, const torch::Tensor& quant_step);
+
+void bias_wsilu_chunk_add_cuda(torch::Tensor& x, const torch::Tensor& bias);
+
+void bias_pixel_shuffle_2_cuda(torch::Tensor& out, const torch::Tensor& x,
+ const torch::Tensor& bias, const int C, const int N, const int W);
+void bias_pixel_shuffle_8_cuda(torch::Tensor& out, const torch::Tensor& x, const torch::Tensor& bias,
+ const int C, const int N, const int W, bool clamp);
+torch::Tensor replicate_pad_cuda(const torch::Tensor& x, const int padB, const int padR);
+
+torch::Tensor round_and_to_int8_cuda(torch::Tensor& z);
+torch::Tensor clamp_reciprocal_with_quant_cuda(const torch::Tensor& q_dec, torch::Tensor& y,
+ const float min_val);
+void add_and_multiply_cuda(torch::Tensor& x0, const torch::Tensor& x1, const torch::Tensor q);
+
+torch::Tensor bias_wsilu_depthwise_conv2d_cuda(const torch::Tensor& x, const torch::Tensor& weight,
+ const torch::Tensor& bias);
+
+class DepthConvProxy {
+public:
+ DepthConvProxy() = default;
+ ~DepthConvProxy() = default;
+
+ void set_param(const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
+ const torch::Tensor& dc_depth_conv_weight,
+ const torch::Tensor& dc_depth_conv_bias, const torch::Tensor& dc_conv2_weight,
+ const torch::Tensor& dc_conv2_bias, const torch::Tensor& ffn_conv1_weight,
+ const torch::Tensor& ffn_conv1_bias, const torch::Tensor& ffn_conv2_weight,
+ const torch::Tensor& ffn_conv2_bias, const bool shortcut);
+ void set_param_with_adaptor(
+ const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
+ const torch::Tensor& dc_depth_conv_weight, const torch::Tensor& dc_depth_conv_bias,
+ const torch::Tensor& dc_conv2_weight, const torch::Tensor& dc_conv2_bias,
+ const torch::Tensor& ffn_conv1_weight, const torch::Tensor& ffn_conv1_bias,
+ const torch::Tensor& ffn_conv2_weight, const torch::Tensor& ffn_conv2_bias,
+ const torch::Tensor& adaptor_weight, const torch::Tensor& adaptor_bias, const bool shortcut);
+ torch::Tensor forward(const torch::Tensor& x);
+ torch::Tensor forward_with_quant_step(const torch::Tensor& x, const torch::Tensor& quant_step);
+ torch::Tensor forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
+ const bool cat_at_front);
+
+private:
+ std::tuple forward_common(const torch::Tensor& x);
+
+private:
+ torch::Tensor _dc_conv1_weight;
+ torch::Tensor _dc_conv1_bias;
+ torch::Tensor _dc_depth_conv_weight;
+ torch::Tensor _dc_conv2_weight;
+ torch::Tensor _dc_conv2_bias;
+ torch::Tensor _ffn_conv1_weight;
+ torch::Tensor _ffn_conv1_bias;
+ torch::Tensor _ffn_conv2_weight;
+ torch::Tensor _ffn_conv2_bias;
+ bool _adaptor{ false };
+ bool _shortcut{ false };
+};
+
// Proxy for a subpixel 2x upsampling convolution: conv2d followed by a fused
// bias-add + pixel-shuffle(2) CUDA kernel.
class SubpelConv2xProxy {
public:
    SubpelConv2xProxy() = default;
    ~SubpelConv2xProxy() = default;

    void set_param(const torch::Tensor& weight, const torch::Tensor& bias, const int padding);
    torch::Tensor forward(const torch::Tensor& x);
    // Writes the shuffled output and `to_cat` into one pre-allocated buffer,
    // avoiding a separate torch::cat.
    torch::Tensor forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
                                   const bool cat_at_front);

private:
    torch::Tensor _weight;
    torch::Tensor _bias;
    int _padding{ 0 };
};
diff --git a/DCVC-RT/src/layers/extensions/inference/impl.cpp b/DCVC-RT/src/layers/extensions/inference/impl.cpp
new file mode 100644
index 0000000..6419fff
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/impl.cpp
@@ -0,0 +1,167 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "def.h"
+namespace F = torch::nn::functional;
+
// Store the block parameters, pre-folding the depthwise-conv bias through the
// second 1x1 conv so the depthwise stage can run bias-free at inference.
void DepthConvProxy::set_param(
    const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
    const torch::Tensor& dc_depth_conv_weight, const torch::Tensor& dc_depth_conv_bias,
    const torch::Tensor& dc_conv2_weight, const torch::Tensor& dc_conv2_bias,
    const torch::Tensor& ffn_conv1_weight, const torch::Tensor& ffn_conv1_bias,
    const torch::Tensor& ffn_conv2_weight, const torch::Tensor& ffn_conv2_bias, const bool shortcut)
{
    _dc_conv1_weight = dc_conv1_weight;
    _dc_conv1_bias = dc_conv1_bias;
    _dc_depth_conv_weight = dc_depth_conv_weight;
    _dc_conv2_weight = dc_conv2_weight;
    // Fold the depthwise bias through conv2: conv2(depth_bias) + conv2_bias.
    _dc_conv2_bias = F::conv2d(dc_depth_conv_bias.reshape({ 1, -1, 1, 1 }), dc_conv2_weight);
    _dc_conv2_bias = _dc_conv2_bias.index({ 0, torch::indexing::Slice(), 0, 0 }) + dc_conv2_bias;
    _ffn_conv1_weight = ffn_conv1_weight;
    _ffn_conv1_bias = ffn_conv1_bias;
    _ffn_conv2_weight = ffn_conv2_weight;
    _ffn_conv2_bias = ffn_conv2_bias;
    _shortcut = shortcut;
}
+
// Same as set_param, but additionally folds a 1x1 channel adaptor into the
// first conv (weights composed, adaptor stacked to also produce the identity
// branch) and into the biases, so forward_common needs no separate adaptor op.
void DepthConvProxy::set_param_with_adaptor(
    const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
    const torch::Tensor& dc_depth_conv_weight, const torch::Tensor& dc_depth_conv_bias,
    const torch::Tensor& dc_conv2_weight, const torch::Tensor& dc_conv2_bias,
    const torch::Tensor& ffn_conv1_weight, const torch::Tensor& ffn_conv1_bias,
    const torch::Tensor& ffn_conv2_weight, const torch::Tensor& ffn_conv2_bias,
    const torch::Tensor& adaptor_weight, const torch::Tensor& adaptor_bias, const bool shortcut)
{
    // Compose conv1 with the adaptor (conv1 ∘ adaptor), then append the raw
    // adaptor rows so one conv2d yields both the main and identity outputs.
    _dc_conv1_weight = F::conv2d(torch::transpose(adaptor_weight, 0, 1), dc_conv1_weight);
    _dc_conv1_weight = torch::transpose(_dc_conv1_weight, 0, 1);
    _dc_conv1_weight = torch::cat({ _dc_conv1_weight, adaptor_weight }, 0);
    // Fold the adaptor bias through conv1.
    _dc_conv1_bias = F::conv2d(adaptor_bias.reshape({ 1, -1, 1, 1 }), dc_conv1_weight);
    _dc_conv1_bias = _dc_conv1_bias.index({ 0, torch::indexing::Slice(), 0, 0 }) + dc_conv1_bias;
    _dc_depth_conv_weight = dc_depth_conv_weight;
    _dc_conv2_weight = dc_conv2_weight;
    // Fold the depthwise bias through conv2, plus the adaptor bias carried by
    // the identity branch.
    _dc_conv2_bias = F::conv2d(dc_depth_conv_bias.reshape({ 1, -1, 1, 1 }), dc_conv2_weight);
    _dc_conv2_bias = _dc_conv2_bias.index({ 0, torch::indexing::Slice(), 0, 0 }) + dc_conv2_bias;
    _dc_conv2_bias = _dc_conv2_bias + adaptor_bias;
    _ffn_conv1_weight = ffn_conv1_weight;
    _ffn_conv1_bias = ffn_conv1_bias;
    _ffn_conv2_weight = ffn_conv2_weight;
    _ffn_conv2_bias = ffn_conv2_bias;
    _shortcut = shortcut;
    _adaptor = true;
}
+
+std::tuple DepthConvProxy::forward_common(const torch::Tensor& x)
+{
+ auto identity = x;
+ // depthconv
+ torch::Tensor out;
+ if (_adaptor) {
+ // NOTE: Here we always fuse adaptor with the first conv1x1 (when even in_ch > out_ch).
+ // It brings larger MACs, but it faster on A100 due to lower memory cost.
+ auto out_identity = F::conv2d(identity, _dc_conv1_weight);
+ auto chunks = torch::chunk(out_identity, 2, 1);
+ out = chunks[0];
+ identity = chunks[1];
+ } else {
+ out = F::conv2d(identity, _dc_conv1_weight);
+ }
+ out = bias_wsilu_depthwise_conv2d_cuda(out, _dc_depth_conv_weight, _dc_conv1_bias);
+ out = F::conv2d(out, _dc_conv2_weight);
+
+ if (_shortcut) {
+ bias_shortcut_2_cuda(out, _dc_conv2_bias, identity);
+ } else {
+ bias_shortcut_cuda(out, _dc_conv2_bias, identity);
+ identity = out;
+ }
+ // ffn
+ out = F::conv2d(out, _ffn_conv1_weight);
+ bias_wsilu_chunk_add_cuda(out, _ffn_conv1_bias);
+ out = F::conv2d(out, _ffn_conv2_weight);
+ return { out, identity };
+}
+
+torch::Tensor DepthConvProxy::forward(const torch::Tensor& x)
+{
+ auto [out, identity] = forward_common(x);
+ bias_shortcut_cuda(out, _ffn_conv2_bias, identity);
+ return out;
+}
+
+torch::Tensor DepthConvProxy::forward_with_quant_step(const torch::Tensor& x,
+ const torch::Tensor& quant_step)
+{
+ auto [out, identity] = forward_common(x);
+ bias_shortcut_with_quant_step_cuda(out, _ffn_conv2_bias, quant_step, identity);
+ return out;
+}
+
+torch::Tensor DepthConvProxy::forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
+ const bool cat_at_front)
+{
+ auto [t, identity] = forward_common(x);
+
+ auto t_shape = t.sizes();
+ auto B = t_shape[0];
+ auto C = t_shape[1];
+ auto H = t_shape[2];
+ auto W = t_shape[3];
+ auto add_ch = to_cat.sizes()[1];
+ auto out = torch::empty({ B, C + add_ch, H, W }, t.options());
+ if (cat_at_front) {
+ auto t_out = out.narrow(1, add_ch, C);
+ bias_shortcut_no_inplace_cuda(t_out, t, _ffn_conv2_bias, identity);
+ out.narrow(1, 0, add_ch) = to_cat;
+ } else {
+ auto t_out = out.narrow(1, 0, C);
+ bias_shortcut_no_inplace_cuda(t_out, t, _ffn_conv2_bias, identity);
+ out.narrow(1, C, add_ch) = to_cat;
+ }
+ return out;
+}
+
// Cache the convolution weight/bias and padding for the fused forward paths.
void SubpelConv2xProxy::set_param(const torch::Tensor& weight, const torch::Tensor& bias,
                                  const int padding)
{
    _weight = weight;
    _bias = bias;
    _padding = padding;
}
+
+torch::Tensor SubpelConv2xProxy::forward(const torch::Tensor& x)
+{
+ auto t = F::conv2d(x, _weight, F::Conv2dFuncOptions().padding(_padding));
+ auto t_shape = t.sizes();
+ auto B = t_shape[0];
+ auto C = t_shape[1];
+ auto H = t_shape[2];
+ auto W = t_shape[3];
+ auto out = torch::empty({ B, C / 4, H * 2, W * 2 }, t.options());
+ assert(B == 1);
+ bias_pixel_shuffle_2_cuda(out, t, _bias, C, H * W, W);
+ return out;
+}
+
+torch::Tensor SubpelConv2xProxy::forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
+ const bool cat_at_front)
+{
+ auto t = F::conv2d(x, _weight, F::Conv2dFuncOptions().padding(_padding));
+ auto t_shape = t.sizes();
+ auto B = t_shape[0];
+ auto C = t_shape[1];
+ auto H = t_shape[2];
+ auto W = t_shape[3];
+ auto add_ch = to_cat.sizes()[1];
+ auto out = torch::empty({ B, add_ch + C / 4, H * 2, W * 2 }, t.options());
+ assert(B == 1);
+ if (cat_at_front) {
+ auto t_out = out.narrow(1, add_ch, C / 4);
+ bias_pixel_shuffle_2_cuda(t_out, t, _bias, C, H * W, W);
+ out.narrow(1, 0, add_ch) = to_cat;
+ } else {
+ auto t_out = out.narrow(1, 0, C / 4);
+ bias_pixel_shuffle_2_cuda(t_out, t, _bias, C, H * W, W);
+ out.narrow(1, C / 4, add_ch) = to_cat;
+ }
+ return out;
+}
diff --git a/DCVC-RT/src/layers/extensions/inference/kernel.cu b/DCVC-RT/src/layers/extensions/inference/kernel.cu
new file mode 100644
index 0000000..41ebc78
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/kernel.cu
@@ -0,0 +1,1150 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include
+#include
+#include