diff --git a/DCVC-RT/README.md b/DCVC-RT/README.md
new file mode 100644
index 0000000..8a0018e
--- /dev/null
+++ b/DCVC-RT/README.md
@@ -0,0 +1,113 @@
+# Introduction
+
+Official PyTorch implementation for DCVC-RT: [Towards Practical **R**eal-**T**ime Neural Video Compression](https://arxiv.org/abs/2502.20762), in CVPR 2025.
+
+# Prerequisites
+* Python 3.12 and conda, get [Conda](https://www.anaconda.com/)
+* CUDA 12.6 (other versions may also work; make sure the CUDA version matches your PyTorch build)
+* PyTorch (tested with PyTorch 2.6; other versions may also work)
+* Environment
+    ```
+    conda create -n $YOUR_PY_ENV_NAME python=3.12
+    conda activate $YOUR_PY_ENV_NAME
+
+    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
+    pip install -r requirements.txt
+    ```
+
+# Test dataset
+
+Arbitrary original resolutions are supported: the input video is padded automatically, the reconstructed video is cropped back to the original size, and the distortion (PSNR) is calculated at the original resolution.
+
+## YUV 420 content
+
+Put the *.yuv files in a folder structure similar to the following:
+
+    /media/data/HEVC_B/
+        - BQTerrace_1920x1080_60.yuv
+        - BasketballDrive_1920x1080_50.yuv
+        - ...
+    /media/data/HEVC_D/
+    /media/data/HEVC_C/
+    ...
+
+The dataset structure is described in dataset_config_example_yuv420.json.
+
+## RGB content
+
+We highly suggest testing YUV420 content. To test RGB content, please refer to the [DCVC-FM](../DCVC-FM) folder.
+
+# Build the project
+Please build the C++ code to support bitstream writing, and the customized CUDA kernels to fuse operations.
+
+```bash
+sudo apt-get install cmake g++ ninja-build
+conda activate $YOUR_PY_ENV_NAME
+cd ./src/cpp/
+pip install .
+cd ../layers/extensions/inference/
+pip install .
+```
+
+# CPU performance scaling
+
+Note that the arithmetic coding runs on the CPU. Please make sure your CPU runs at high performance while writing the actual bitstream.
Otherwise, the arithmetic coding may take a long time.
+
+Check the CPU frequency:
+```
+grep -E '^model name|^cpu MHz' /proc/cpuinfo
+```
+
+Run the following command to set the CPU to its maximum frequency:
+```
+echo performance | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+```
+
+Run the following command to restore the default frequency scaling:
+```
+echo ondemand | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+```
+
+# Pretrained models
+
+* Download [our pretrained models](https://1drv.ms/f/c/2866592d5c55df8c/Esu0KJ-I2kxCjEP565ARx_YB88i0UnR6XnODqFcvZs4LcA?e=by8CO8) and put them into the ./checkpoints folder.
+* There are two models: one for image coding and the other for video coding.
+
+# Test the models
+
+Example of testing the pretrained models with four rate points:
+```bash
+python test_video.py --model_path_i ./checkpoints/cvpr2025_image.pth.tar --model_path_p ./checkpoints/cvpr2025_video.pth.tar --rate_num 4 --test_config ./dataset_config_example_yuv420.json --cuda 1 -w 1 --write_stream 1 --force_zero_thres 0.12 --output_path output.json --force_intra_period -1 --reset_interval 64 --force_frame_num -1 --check_existing 0
+```
+
+It is recommended to set ```-w``` equal to the number of your GPUs.
+
+You can also specify a different ```--rate_num``` value (2~64) for finer bitrate adjustment.
+
+# Comparing with other methods
+Bit saving over VTM-17.0 (UVG, all frames, with the single-intra-frame setting (i.e., intra-period = -1) and YUV420 colorspace):
+
+![RD curve](assets/RD-Curve.png)
+
+The BD-rate and encoding/decoding speed on an Nvidia A100 GPU:
+
+![BD-rate and speed](assets/bd_rate_speed.png)
+
+# Acknowledgement
+The implementation is based on [CompressAI](https://github.com/InterDigitalInc/CompressAI).
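+
+For reference, the YUV420 frame layout assumed by the dataset configs above can be sketched in a few lines of Python. This is an illustrative helper, not part of the codebase; the 16-pixel padding alignment used here is an assumption for illustration only (the codec pads automatically with its own alignment):
+
+```python
+import numpy as np
+
+def read_yuv420_frame(f, width, height):
+    """Read one 8-bit YUV420 frame: full-resolution Y, quarter-resolution U and V."""
+    y = np.frombuffer(f.read(width * height), dtype=np.uint8).reshape(height, width)
+    u = np.frombuffer(f.read(width * height // 4), dtype=np.uint8).reshape(height // 2, width // 2)
+    v = np.frombuffer(f.read(width * height // 4), dtype=np.uint8).reshape(height // 2, width // 2)
+    return y, u, v
+
+def padded_size(width, height, align=16):
+    """Round each dimension up to a multiple of `align` (alignment value is illustrative)."""
+    def pad(x):
+        return (x + align - 1) // align * align
+    return pad(width), pad(height)
+```
+
+For example, a 1920x1080 sequence would be padded to 1920x1088 under a 16-pixel alignment, then cropped back to 1920x1080 after reconstruction.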
+
+# Citation
+If you find this work useful for your research, please cite:
+
+```
+@inproceedings{jia2025towards,
+  title={Towards Practical Real-Time Neural Video Compression},
+  author={Jia, Zhaoyang and Li, Bin and Li, Jiahao and Xie, Wenxuan and Qi, Linfeng and Li, Houqiang and Lu, Yan},
+  booktitle={{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
+             {CVPR} 2025, Nashville, TN, USA, June 11-15, 2025},
+  year={2025}
+}
+```
+
+# Trademarks
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.
\ No newline at end of file
diff --git a/DCVC-RT/assets/RD-Curve.png b/DCVC-RT/assets/RD-Curve.png
new file mode 100644
index 0000000..1d93daf
Binary files /dev/null and b/DCVC-RT/assets/RD-Curve.png differ
diff --git a/DCVC-RT/assets/bd_rate_speed.png b/DCVC-RT/assets/bd_rate_speed.png
new file mode 100644
index 0000000..6da70d1
Binary files /dev/null and b/DCVC-RT/assets/bd_rate_speed.png differ
diff --git a/DCVC-RT/dataset_config_example_yuv420.json b/DCVC-RT/dataset_config_example_yuv420.json
new file mode 100644
index 0000000..3414f49
--- /dev/null
+++ b/DCVC-RT/dataset_config_example_yuv420.json
@@ -0,0 +1,100 @@
+{
+    "root_path": "/media/data/",
+    "test_classes": {
+        "UVG": {
+            "test": 1,
+            "base_path": "UVG",
+            "src_type": "yuv420",
+            "sequences": {
+                "Beauty_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+                "Bosphorus_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
"HoneyBee_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1}, + "Jockey_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1}, + "ReadySteadyGo_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1}, + "ShakeNDry_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 300, "intra_period": -1}, + "YachtRide_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1} + } + }, + "MCL-JCV": { + "test": 1, + "base_path": "MCL-JCV", + "src_type": "yuv420", + "sequences": { + "videoSRC01_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC02_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC03_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC04_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC05_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1}, + "videoSRC06_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1}, + "videoSRC07_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1}, + "videoSRC08_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1}, + "videoSRC09_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1}, + "videoSRC10_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC11_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC12_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC13_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, 
+ "videoSRC14_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC15_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC16_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC17_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1}, + "videoSRC18_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1}, + "videoSRC19_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC20_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1}, + "videoSRC21_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1}, + "videoSRC22_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1}, + "videoSRC23_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1}, + "videoSRC24_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1}, + "videoSRC25_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1}, + "videoSRC26_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC27_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC28_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}, + "videoSRC29_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1}, + "videoSRC30_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1} + } + }, + "HEVC_B": { + "test": 1, + "base_path": "HEVC_B", + "src_type": "yuv420", + "sequences": { + "BQTerrace_1920x1080_60.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1}, + "BasketballDrive_1920x1080_50.yuv": {"width": 1920, "height": 1080, "frames": 
500, "intra_period": -1}, + "Cactus_1920x1080_50.yuv": {"width": 1920, "height": 1080, "frames": 500, "intra_period": -1}, + "Kimono1_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 240, "intra_period": -1}, + "ParkScene_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 240, "intra_period": -1} + } + }, + "HEVC_E": { + "test": 1, + "base_path": "HEVC_E", + "src_type": "yuv420", + "sequences": { + "FourPeople_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1}, + "Johnny_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1}, + "KristenAndSara_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1} + } + }, + "HEVC_C": { + "test": 1, + "base_path": "HEVC_C", + "src_type": "yuv420", + "sequences": { + "BQMall_832x480_60.yuv": {"width": 832, "height": 480, "frames": 600, "intra_period": -1}, + "BasketballDrill_832x480_50.yuv": {"width": 832, "height": 480, "frames": 500, "intra_period": -1}, + "PartyScene_832x480_50.yuv": {"width": 832, "height": 480, "frames": 500, "intra_period": -1}, + "RaceHorses_832x480_30.yuv": {"width": 832, "height": 480, "frames": 300, "intra_period": -1} + } + }, + "HEVC_D": { + "test": 1, + "base_path": "HEVC_D", + "src_type": "yuv420", + "sequences": { + "BasketballPass_416x240_50.yuv": {"width": 416, "height": 240, "frames": 500, "intra_period": -1}, + "BlowingBubbles_416x240_50.yuv": {"width": 416, "height": 240, "frames": 500, "intra_period": -1}, + "BQSquare_416x240_60.yuv": {"width": 416, "height": 240, "frames": 600, "intra_period": -1}, + "RaceHorses_416x240_30.yuv": {"width": 416, "height": 240, "frames": 300, "intra_period": -1} + } + } + } +} \ No newline at end of file diff --git a/DCVC-RT/requirements.txt b/DCVC-RT/requirements.txt new file mode 100644 index 0000000..21ad776 --- /dev/null +++ b/DCVC-RT/requirements.txt @@ -0,0 +1,7 @@ +numpy>=1.20.0 +scipy +matplotlib +tqdm +bd-metric +pillow +pybind11 diff --git 
a/DCVC-RT/src/cpp/py_rans/py_rans.cpp b/DCVC-RT/src/cpp/py_rans/py_rans.cpp
new file mode 100644
index 0000000..1f5d047
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/py_rans.cpp
@@ -0,0 +1,393 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "py_rans.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstring>
+#include <numeric>
+
+namespace py = pybind11;
+
+RansEncoder::RansEncoder()
+{
+    m_encoder0 = std::make_shared<RansEncoderLibMultiThread>();
+    m_encoder1 = std::make_shared<RansEncoderLibMultiThread>();
+}
+
+void RansEncoder::encode_y(const py::array_t<int16_t>& symbols, const int cdf_group_index)
+{
+    py::buffer_info symbols_buf = symbols.request();
+    int16_t* symbols_ptr = static_cast<int16_t*>(symbols_buf.ptr);
+
+    int symbolSize = static_cast<int>(symbols.size());
+    if (m_use_two_encoders) {
+        int symbolSize0 = symbolSize / 2;
+        int symbolSize1 = symbolSize - symbolSize0;
+
+        auto vec_symbols0 = std::make_shared<std::vector<int16_t>>(symbolSize0);
+        memcpy(vec_symbols0->data(), symbols_ptr, symbolSize0 * sizeof(int16_t));
+        m_encoder0->encode_y(vec_symbols0, cdf_group_index);
+        auto vec_symbols1 = std::make_shared<std::vector<int16_t>>(symbolSize1);
+        memcpy(vec_symbols1->data(), symbols_ptr + symbolSize0, symbolSize1 * sizeof(int16_t));
+        m_encoder1->encode_y(vec_symbols1, cdf_group_index);
+    } else {
+        auto vec_symbols0 = std::make_shared<std::vector<int16_t>>(symbolSize);
+        memcpy(vec_symbols0->data(), symbols_ptr, symbolSize * sizeof(int16_t));
+        m_encoder0->encode_y(vec_symbols0, cdf_group_index);
+    }
+}
+
+void RansEncoder::encode_z(const py::array_t<int8_t>& symbols, const int cdf_group_index,
+                           const int start_offset, const int per_channel_size)
+{
+    py::buffer_info symbols_buf = symbols.request();
+    int8_t* symbols_ptr = static_cast<int8_t*>(symbols_buf.ptr);
+
+    int symbolSize = static_cast<int>(symbols.size());
+    if (m_use_two_encoders) {
+        int symbolSize0 = symbolSize / 2;
+        int symbolSize1 = symbolSize - symbolSize0;
+        int channel_half = symbolSize0 / per_channel_size;
+
+        auto vec_symbols0 = std::make_shared<std::vector<int8_t>>(symbolSize0);
+        memcpy(vec_symbols0->data(), symbols_ptr, symbolSize0 *
sizeof(int8_t));
+        m_encoder0->encode_z(vec_symbols0, cdf_group_index, start_offset, per_channel_size);
+        auto vec_symbols1 = std::make_shared<std::vector<int8_t>>(symbolSize1);
+        memcpy(vec_symbols1->data(), symbols_ptr + symbolSize0, symbolSize1 * sizeof(int8_t));
+        m_encoder1->encode_z(vec_symbols1, cdf_group_index, start_offset + channel_half,
+                             per_channel_size);
+    } else {
+        auto vec_symbols0 = std::make_shared<std::vector<int8_t>>(symbolSize);
+        memcpy(vec_symbols0->data(), symbols_ptr, symbolSize * sizeof(int8_t));
+        m_encoder0->encode_z(vec_symbols0, cdf_group_index, start_offset, per_channel_size);
+    }
+}
+
+int RansEncoder::add_cdf(const py::array_t<int32_t>& cdfs, const py::array_t<int32_t>& cdfs_sizes,
+                         const py::array_t<int32_t>& offsets)
+{
+    py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
+    py::buffer_info offsets_buf = offsets.request();
+    int32_t* cdfs_sizes_ptr = static_cast<int32_t*>(cdfs_sizes_buf.ptr);
+    int32_t* offsets_ptr = static_cast<int32_t*>(offsets_buf.ptr);
+
+    int cdf_num = static_cast<int>(cdfs_sizes.size());
+    auto vec_cdfs_sizes = std::make_shared<std::vector<int32_t>>(cdf_num);
+    memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
+    auto vec_offsets = std::make_shared<std::vector<int32_t>>(offsets.size());
+    memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
+
+    int per_vector_size = static_cast<int>(cdfs.size() / cdf_num);
+    auto vec_cdfs = std::make_shared<std::vector<std::vector<int32_t>>>(cdf_num);
+    auto cdfs_raw = cdfs.unchecked<2>();
+    for (int i = 0; i < cdf_num; i++) {
+        std::vector<int32_t> t(per_vector_size);
+        memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
+        vec_cdfs->at(i) = t;
+    }
+
+    int cdf_idx = m_encoder0->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+    m_encoder1->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+    return cdf_idx;
+}
+
+void RansEncoder::empty_cdf_buffer()
+{
+    m_encoder0->empty_cdf_buffer();
+    m_encoder1->empty_cdf_buffer();
+}
+
+void RansEncoder::flush()
+{
+    m_encoder0->flush();
+    m_encoder1->flush();
+}
+
+py::array_t<uint8_t> RansEncoder::get_encoded_stream()
+{
+    if (m_use_two_encoders) {
auto result0 = m_encoder0->get_encoded_stream();
+        int nbytes0 = static_cast<int>(result0->size());
+        auto result1 = m_encoder1->get_encoded_stream();
+        int nbytes1 = static_cast<int>(result1->size());
+
+        int identical_bytes = 0;
+        int check_bytes = std::min(nbytes0, nbytes1);
+        check_bytes = std::min(check_bytes, 8);
+        for (int i = 0; i < check_bytes; i++) {
+            if (result0->at(nbytes0 - 1 - i) != 0) {
+                break;
+            }
+            if (result1->at(nbytes1 - 1 - i) != 0) {
+                break;
+            }
+            identical_bytes++;
+        }
+        if (identical_bytes == 0 && result0->at(nbytes0 - 1) == result1->at(nbytes1 - 1)) {
+            identical_bytes = 1;
+        }
+
+        py::array_t<uint8_t> stream(nbytes0 + nbytes1 - identical_bytes);
+        py::buffer_info stream_buf = stream.request();
+        uint8_t* stream_ptr = static_cast<uint8_t*>(stream_buf.ptr);
+
+        std::copy(result0->begin(), result0->end(), stream_ptr);
+        std::reverse_copy(result1->begin(), result1->end() - identical_bytes, stream_ptr + nbytes0);
+        return stream;
+    }
+
+    auto result0 = m_encoder0->get_encoded_stream();
+    int nbytes0 = static_cast<int>(result0->size());
+
+    py::array_t<uint8_t> stream(nbytes0);
+    py::buffer_info stream_buf = stream.request();
+    uint8_t* stream_ptr = static_cast<uint8_t*>(stream_buf.ptr);
+
+    std::copy(result0->begin(), result0->end(), stream_ptr);
+    return stream;
+}
+
+void RansEncoder::reset()
+{
+    m_encoder0->reset();
+    m_encoder1->reset();
+}
+
+void RansEncoder::set_use_two_encoders(bool b)
+{
+    m_use_two_encoders = b;
+}
+
+bool RansEncoder::get_use_two_encoders()
+{
+    return m_use_two_encoders;
+}
+
+RansDecoder::RansDecoder()
+{
+    m_decoder0 = std::make_shared<RansDecoderLibMultiThread>();
+    m_decoder1 = std::make_shared<RansDecoderLibMultiThread>();
+}
+
+void RansDecoder::set_stream(const py::array_t<uint8_t>& encoded)
+{
+    py::buffer_info encoded_buf = encoded.request();
+    const uint8_t* encoded_ptr = static_cast<const uint8_t*>(encoded_buf.ptr);
+    const int encoded_size = static_cast<int>(encoded.size());
+    auto stream0 = std::make_shared<std::vector<uint8_t>>(encoded.size());
+    std::copy(encoded_ptr, encoded_ptr + encoded_size, stream0->data());
+    m_decoder0->set_stream(stream0);
+    if
(m_use_two_decoders) {
+        auto stream1 = std::make_shared<std::vector<uint8_t>>(encoded.size());
+        std::reverse_copy(encoded_ptr, encoded_ptr + encoded_size, stream1->data());
+        m_decoder1->set_stream(stream1);
+    }
+}
+
+void RansDecoder::decode_y(const py::array_t<uint8_t>& indexes, const int cdf_group_index)
+{
+    py::buffer_info indexes_buf = indexes.request();
+    uint8_t* indexes_ptr = static_cast<uint8_t*>(indexes_buf.ptr);
+
+    int indexSize = static_cast<int>(indexes.size());
+    if (m_use_two_decoders) {
+        int indexSize0 = indexSize / 2;
+        int indexSize1 = indexSize - indexSize0;
+
+        auto vec_indexes0 = std::make_shared<std::vector<uint8_t>>(indexSize0);
+        std::copy(indexes_ptr, indexes_ptr + indexSize0, vec_indexes0->data());
+        m_decoder0->decode_y(vec_indexes0, cdf_group_index);
+
+        auto vec_indexes1 = std::make_shared<std::vector<uint8_t>>(indexSize1);
+        std::copy(indexes_ptr + indexSize0, indexes_ptr + indexSize, vec_indexes1->data());
+        m_decoder1->decode_y(vec_indexes1, cdf_group_index);
+    } else {
+        auto vec_indexes0 = std::make_shared<std::vector<uint8_t>>(indexSize);
+        std::copy(indexes_ptr, indexes_ptr + indexSize, vec_indexes0->data());
+        m_decoder0->decode_y(vec_indexes0, cdf_group_index);
+    }
+}
+
+py::array_t<int8_t> RansDecoder::decode_and_get_y(const py::array_t<uint8_t>& indexes,
+                                                  const int cdf_group_index)
+{
+    decode_y(indexes, cdf_group_index);
+    return get_decoded_tensor();
+}
+
+void RansDecoder::decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+                           const int per_channel_size)
+{
+    if (m_use_two_decoders) {
+        int symbolSize0 = total_size / 2;
+        int symbolSize1 = total_size - symbolSize0;
+        int channel_half = symbolSize0 / per_channel_size;
+        m_decoder0->decode_z(symbolSize0, cdf_group_index, start_offset, per_channel_size);
+        m_decoder1->decode_z(symbolSize1, cdf_group_index, start_offset + channel_half,
+                             per_channel_size);
+    } else {
+        m_decoder0->decode_z(total_size, cdf_group_index, start_offset, per_channel_size);
+    }
+}
+
+py::array_t<int8_t> RansDecoder::get_decoded_tensor()
+{
+    if (m_use_two_decoders) {
+        auto result0 =
m_decoder0->get_decoded_tensor();
+        const int total_size0 = static_cast<int>(result0->size());
+
+        auto result1 = m_decoder1->get_decoded_tensor();
+        const int total_size1 = static_cast<int>(result1->size());
+        py::array_t<int8_t> output(total_size0 + total_size1);
+        py::buffer_info buf = output.request();
+        int8_t* buf_ptr = static_cast<int8_t*>(buf.ptr);
+        std::copy(result0->begin(), result0->end(), buf_ptr);
+        std::copy(result1->begin(), result1->end(), buf_ptr + total_size0);
+
+        return output;
+    }
+
+    auto result0 = m_decoder0->get_decoded_tensor();
+    const int total_size0 = static_cast<int>(result0->size());
+
+    py::array_t<int8_t> output(total_size0);
+    py::buffer_info buf = output.request();
+    int8_t* buf_ptr = static_cast<int8_t*>(buf.ptr);
+    std::copy(result0->begin(), result0->end(), buf_ptr);
+
+    return output;
+}
+
+int RansDecoder::add_cdf(const py::array_t<int32_t>& cdfs, const py::array_t<int32_t>& cdfs_sizes,
+                         const py::array_t<int32_t>& offsets)
+{
+    py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
+    py::buffer_info offsets_buf = offsets.request();
+    int32_t* cdfs_sizes_ptr = static_cast<int32_t*>(cdfs_sizes_buf.ptr);
+    int32_t* offsets_ptr = static_cast<int32_t*>(offsets_buf.ptr);
+
+    int cdf_num = static_cast<int>(cdfs_sizes.size());
+    auto vec_cdfs_sizes = std::make_shared<std::vector<int32_t>>(cdf_num);
+    memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
+    auto vec_offsets = std::make_shared<std::vector<int32_t>>(offsets.size());
+    memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
+
+    int per_vector_size = static_cast<int>(cdfs.size() / cdf_num);
+    auto vec_cdfs = std::make_shared<std::vector<std::vector<int32_t>>>(cdf_num);
+    auto cdfs_raw = cdfs.unchecked<2>();
+    for (int i = 0; i < cdf_num; i++) {
+        std::vector<int32_t> t(per_vector_size);
+        memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
+        vec_cdfs->at(i) = t;
+    }
+    int cdf_idx = m_decoder0->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+    m_decoder1->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+    return cdf_idx;
+}
+
+void RansDecoder::empty_cdf_buffer()
+{
m_decoder0->empty_cdf_buffer();
+    m_decoder1->empty_cdf_buffer();
+}
+
+void RansDecoder::set_use_two_decoders(bool b)
+{
+    m_use_two_decoders = b;
+}
+
+bool RansDecoder::get_use_two_decoders()
+{
+    return m_use_two_decoders;
+}
+
+std::vector<uint32_t> pmf_to_quantized_cdf(const std::vector<float>& pmf, int precision)
+{
+    /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
+     * although it's only run once per model after training. See TF/compression
+     * implementation for an optimized version. */
+
+    std::vector<uint32_t> cdf(pmf.size() + 1);
+    cdf[0] = 0; /* freq 0 */
+
+    std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) {
+        return static_cast<uint32_t>(std::round(p * (1 << precision)) + 0.5);
+    });
+
+    const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0);
+
+    std::transform(cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) {
+        return static_cast<uint32_t>((((1ull << precision) * p) / total));
+    });
+
+    std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
+    cdf.back() = 1 << precision;
+
+    for (int i = 0; i < static_cast<int>(cdf.size() - 1); ++i) {
+        if (cdf[i] == cdf[i + 1]) {
+            /* Try to steal frequency from low-frequency symbols */
+            uint32_t best_freq = ~0u;
+            int best_steal = -1;
+            for (int j = 0; j < static_cast<int>(cdf.size()) - 1; ++j) {
+                uint32_t freq = cdf[j + 1] - cdf[j];
+                if (freq > 1 && freq < best_freq) {
+                    best_freq = freq;
+                    best_steal = j;
+                }
+            }
+
+            assert(best_steal != -1);
+
+            if (best_steal < i) {
+                for (int j = best_steal + 1; j <= i; ++j) {
+                    cdf[j]--;
+                }
+            } else {
+                assert(best_steal > i);
+                for (int j = i + 1; j <= best_steal; ++j) {
+                    cdf[j]++;
+                }
+            }
+        }
+    }
+
+    assert(cdf[0] == 0);
+    assert(cdf.back() == (1u << precision));
+    for (int i = 0; i < static_cast<int>(cdf.size()) - 1; ++i) {
+        assert(cdf[i + 1] > cdf[i]);
+    }
+
+    return cdf;
+}
+
+PYBIND11_MODULE(MLCodec_extensions_cpp, m)
+{
+    py::class_<RansEncoder>(m, "RansEncoder")
+        .def(py::init<>())
+        .def("encode_y", &RansEncoder::encode_y)
+        .def("encode_z",
&RansEncoder::encode_z)
+        .def("flush", &RansEncoder::flush)
+        .def("get_encoded_stream", &RansEncoder::get_encoded_stream)
+        .def("reset", &RansEncoder::reset)
+        .def("add_cdf", &RansEncoder::add_cdf)
+        .def("empty_cdf_buffer", &RansEncoder::empty_cdf_buffer)
+        .def("set_use_two_encoders", &RansEncoder::set_use_two_encoders)
+        .def("get_use_two_encoders", &RansEncoder::get_use_two_encoders);
+
+    py::class_<RansDecoder>(m, "RansDecoder")
+        .def(py::init<>())
+        .def("set_stream", &RansDecoder::set_stream)
+        .def("decode_y", &RansDecoder::decode_y)
+        .def("decode_and_get_y", &RansDecoder::decode_and_get_y)
+        .def("decode_z", &RansDecoder::decode_z)
+        .def("get_decoded_tensor", &RansDecoder::get_decoded_tensor)
+        .def("add_cdf", &RansDecoder::add_cdf)
+        .def("empty_cdf_buffer", &RansDecoder::empty_cdf_buffer)
+        .def("set_use_two_decoders", &RansDecoder::set_use_two_decoders)
+        .def("get_use_two_decoders", &RansDecoder::get_use_two_decoders);
+
+    m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, "Return quantized CDF for a given PMF");
+}
diff --git a/DCVC-RT/src/cpp/py_rans/py_rans.h b/DCVC-RT/src/cpp/py_rans/py_rans.h
new file mode 100644
index 0000000..c7d223a
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/py_rans.h
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+#include "rans.h"
+#include <pybind11/numpy.h>
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+namespace py = pybind11;
+
+// the classes in this file only perform the type conversion
+// from python type (numpy) to C++ type (vector)
+class RansEncoder {
+public:
+    RansEncoder();
+
+    RansEncoder(const RansEncoder&) = delete;
+    RansEncoder(RansEncoder&&) = delete;
+    RansEncoder& operator=(const RansEncoder&) = delete;
+    RansEncoder& operator=(RansEncoder&&) = delete;
+
+    void encode_y(const py::array_t<int16_t>& symbols, const int cdf_group_index);
+    void encode_z(const py::array_t<int8_t>& symbols, const int cdf_group_index,
+                  const int start_offset, const int per_channel_size);
+    void flush();
+    py::array_t<uint8_t> get_encoded_stream();
+    void reset();
+    int add_cdf(const py::array_t<int32_t>& cdfs, const py::array_t<int32_t>& cdfs_sizes,
+                const py::array_t<int32_t>& offsets);
+    void empty_cdf_buffer();
+    void set_use_two_encoders(bool b);
+    bool get_use_two_encoders();
+
+private:
+    std::shared_ptr<RansEncoderLib> m_encoder0;
+    std::shared_ptr<RansEncoderLib> m_encoder1;
+    bool m_use_two_encoders{ false };
+};
+
+class RansDecoder {
+public:
+    RansDecoder();
+
+    RansDecoder(const RansDecoder&) = delete;
+    RansDecoder(RansDecoder&&) = delete;
+    RansDecoder& operator=(const RansDecoder&) = delete;
+    RansDecoder& operator=(RansDecoder&&) = delete;
+
+    void set_stream(const py::array_t<uint8_t>&);
+
+    void decode_y(const py::array_t<uint8_t>& indexes, const int cdf_group_index);
+    py::array_t<int8_t> decode_and_get_y(const py::array_t<uint8_t>& indexes, const int cdf_group_index);
+    void decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+                  const int per_channel_size);
+    py::array_t<int8_t> get_decoded_tensor();
+    int add_cdf(const py::array_t<int32_t>& cdfs, const py::array_t<int32_t>& cdfs_sizes,
+                const py::array_t<int32_t>& offsets);
+    void empty_cdf_buffer();
+    void set_use_two_decoders(bool b);
+    bool get_use_two_decoders();
+
+private:
+    std::shared_ptr<RansDecoderLib> m_decoder0;
+    std::shared_ptr<RansDecoderLib> m_decoder1;
+    bool m_use_two_decoders{ false };
+};
+
+std::vector<uint32_t> pmf_to_quantized_cdf(const std::vector<float>& pmf,
int precision); diff --git a/DCVC-RT/src/cpp/py_rans/rans.cpp b/DCVC-RT/src/cpp/py_rans/rans.cpp new file mode 100644 index 0000000..2c42bac --- /dev/null +++ b/DCVC-RT/src/cpp/py_rans/rans.cpp @@ -0,0 +1,534 @@ +/* Copyright 2020 InterDigital Communications, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Rans64 extensions from: + * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ + * Unbounded range coding from: + * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc + **/ + +#include "rans.h" + +#include +#include +#include + +constexpr uint16_t bypass_precision = 2; /* number of bits in bypass mode */ +constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; + +inline void RansEncPutBits(RansState& r, uint8_t*& ptr, uint32_t val) +{ + RansAssert(bypass_precision <= 8); + RansAssert(val < (1u << bypass_precision)); + + constexpr uint32_t freq = 1 << (SCALE_BITS - bypass_precision); + constexpr uint32_t x_max = freq << ENC_RENORM_SHIFT_BITS; + while (r >= x_max) { + *(--ptr) = static_cast(r & 0xff); + r >>= 8; + } + + r = (r << bypass_precision) | val; +} + +inline uint32_t RansDecGetBits(RansState& r, uint8_t*& ptr) +{ + uint32_t val = r & ((1u << bypass_precision) - 1); + + /* Re-normalize */ + r = r >> bypass_precision; + if (r < RANS_BYTE_L) { + r = (r << 8) | *ptr++; + RansAssert(r >= RANS_BYTE_L); + } + + return val; +} + 
+RansEncoderLib::RansEncoderLib()
+{
+    _stream = std::make_shared<std::vector<uint8_t>>();
+}
+
+int RansEncoderLib::add_cdf(const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
+                            const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
+                            const std::shared_ptr<std::vector<int32_t>> offsets)
+{
+
+    auto ransSymbols = std::make_shared<std::vector<std::vector<RansSymbol>>>(cdfs->size());
+    for (int i = 0; i < static_cast<int>(cdfs->size()); i++) {
+        const int32_t* cdf = cdfs->at(i).data();
+        std::vector<RansSymbol> ransSym(cdfs->at(i).size());
+        const int ransSize = static_cast<int>(ransSym.size() - 1);
+        for (int j = 0; j < ransSize; j++) {
+            ransSym[j] = RansSymbol(
+                { static_cast<uint16_t>(cdf[j]), static_cast<uint16_t>(cdf[j + 1] - cdf[j]) });
+        }
+        ransSymbols->at(i) = ransSym;
+    }
+
+    _ransSymbols.push_back(ransSymbols);
+    _cdfs_sizes.push_back(cdfs_sizes);
+    _offsets.push_back(offsets);
+    return static_cast<int>(_ransSymbols.size()) - 1;
+}
+
+void RansEncoderLib::empty_cdf_buffer()
+{
+    _ransSymbols.clear();
+    _cdfs_sizes.clear();
+    _offsets.clear();
+}
+
+FORCE_INLINE void RansEncoderLib::encode_one_symbol(uint8_t*& ptr, RansState& rans, const int32_t symbol,
+                                                    const int32_t cdf_size, const int32_t offset,
+                                                    const std::vector<RansSymbol>& ransSymbols)
+{
+    const int32_t max_value = cdf_size - 2;
+    int32_t value = symbol - offset;
+
+    uint32_t raw_val = 0;
+    if (value < 0) {
+        raw_val = -2 * value - 1;
+        value = max_value;
+    } else if (value >= max_value) {
+        raw_val = 2 * (value - max_value);
+        value = max_value;
+    }
+
+    if (value == max_value) {
+        std::vector<uint16_t> bypassBins;
+        bypassBins.reserve(20);
+        /* Determine the number of bypasses (in bypass_precision size) needed to
+         * encode the raw value.
*/ + int32_t n_bypass = 0; + while ((raw_val >> (n_bypass * bypass_precision)) != 0) { + ++n_bypass; + } + + /* Encode number of bypasses */ + int32_t val = n_bypass; + while (val >= max_bypass_val) { + bypassBins.push_back(max_bypass_val); + val -= max_bypass_val; + } + bypassBins.push_back(static_cast(val)); + + /* Encode raw value */ + for (int32_t j = 0; j < n_bypass; ++j) { + const int32_t val1 = (raw_val >> (j * bypass_precision)) & max_bypass_val; + bypassBins.push_back(static_cast(val1)); + } + + for (auto it = bypassBins.rbegin(); it < bypassBins.rend(); it++) { + RansEncPutBits(rans, ptr, *it); + } + } + RansEncPut(rans, ptr, ransSymbols[value].start, ransSymbols[value].range); +} + +void RansEncoderLib::encode_y(const std::shared_ptr> symbols, + const int cdf_group_index) +{ + PendingTask p; + p.workType = WorkType::EncodeDecodeY; + p.symbols_y = symbols; + p.cdf_group_index = cdf_group_index; + m_pendingEncodingList.push_back(p); +} + +void RansEncoderLib::encode_z(const std::shared_ptr> symbols, + const int cdf_group_index, const int start_offset, + const int per_channel_size) +{ + PendingTask p; + p.workType = WorkType::EncodeDecodeZ; + p.symbols_z = symbols; + p.cdf_group_index = cdf_group_index; + p.start_offset = start_offset; + p.per_channel_size = per_channel_size; + m_pendingEncodingList.push_back(p); +} +#include +FORCE_INLINE void RansEncoderLib::encode_y_internal(uint8_t*& ptr, RansState& rans, + const std::shared_ptr> symbols, + const int cdf_group_index) +{ + // backward loop on symbols from the end; + const int16_t* symbols_ptr = symbols->data(); + const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data(); + const int32_t* offsets_ptr = _offsets[cdf_group_index]->data(); + const int symbol_size = static_cast(symbols->size()); + + for (int i = symbol_size - 1; i >= 0; i--) { + const int32_t combined_symbol = symbols_ptr[i]; + const int32_t cdf_idx = combined_symbol & 0xff; + const int32_t s = combined_symbol >> 8; + 
encode_one_symbol(ptr, rans, s, cdfs_sizes_ptr[cdf_idx], offsets_ptr[cdf_idx], + _ransSymbols[cdf_group_index]->at(cdf_idx)); + } +} + +FORCE_INLINE void RansEncoderLib::encode_z_internal(uint8_t*& ptr, RansState& rans, + const std::shared_ptr> symbols, + const int cdf_group_index, const int start_offset, + const int per_channel_size) +{ + // backward loop on symbols from the end; + const int8_t* symbols_ptr = symbols->data(); + const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data(); + const int32_t* offsets_ptr = _offsets[cdf_group_index]->data(); + const int symbol_size = static_cast(symbols->size()); + + for (int i = symbol_size - 1; i >= 0; i--) { + const int32_t cdf_idx = i / per_channel_size + start_offset; + encode_one_symbol(ptr, rans, symbols_ptr[i], cdfs_sizes_ptr[cdf_idx], offsets_ptr[cdf_idx], + _ransSymbols[cdf_group_index]->at(cdf_idx)); + } +} + +void RansEncoderLib::flush() +{ + RansState rans; + RansEncInit(rans); + + int32_t total_symbol_size = 0; + for (auto it = m_pendingEncodingList.begin(); it != m_pendingEncodingList.end(); it++) { + if (it->workType == WorkType::EncodeDecodeY) { + total_symbol_size += static_cast(it->symbols_y->size()); + } else if (it->workType == WorkType::EncodeDecodeZ) { + total_symbol_size += static_cast(it->symbols_z->size()); + } + } + + if (total_symbol_size == 0) { + _stream->resize(0); + return; + } + + uint8_t* output = new uint8_t[total_symbol_size]; // too much space ? 
+    uint8_t* ptrEnd = output + total_symbol_size;
+    uint8_t* ptr = ptrEnd;
+    assert(ptr != nullptr);
+
+    for (auto it = m_pendingEncodingList.rbegin(); it != m_pendingEncodingList.rend(); it++) {
+        PendingTask p = *it;
+        if (p.workType == WorkType::EncodeDecodeY) {
+            encode_y_internal(ptr, rans, p.symbols_y, p.cdf_group_index);
+        } else if (p.workType == WorkType::EncodeDecodeZ) {
+            encode_z_internal(ptr, rans, p.symbols_z, p.cdf_group_index, p.start_offset,
+                              p.per_channel_size);
+        }
+    }
+
+    RansEncFlush(rans, ptr);
+
+    const int nbytes = static_cast<int>(std::distance(ptr, ptrEnd));
+
+    _stream->resize(nbytes);
+    memcpy(_stream->data(), ptr, nbytes);
+    delete[] output;
+}
+
+std::shared_ptr<std::vector<uint8_t>> RansEncoderLib::get_encoded_stream()
+{
+    return _stream;
+}
+
+void RansEncoderLib::reset()
+{
+    m_pendingEncodingList.clear();
+    _stream->clear();
+}
+
+RansEncoderLibMultiThread::RansEncoderLibMultiThread()
+    : RansEncoderLib()
+    , m_finish(false)
+    , m_result_ready(false)
+    , m_thread(std::thread(&RansEncoderLibMultiThread::worker, this))
+{
+}
+
+RansEncoderLibMultiThread::~RansEncoderLibMultiThread()
+{
+    {
+        std::lock_guard<std::mutex> lk(m_mutex_pending);
+        std::lock_guard<std::mutex> lk1(m_mutex_result);
+        m_finish = true;
+    }
+    m_cv_pending.notify_one();
+    m_cv_result.notify_one();
+    m_thread.join();
+}
+
+void RansEncoderLibMultiThread::flush()
+{
+    PendingTask p;
+    p.workType = WorkType::Flush;
+    {
+        std::unique_lock<std::mutex> lk(m_mutex_pending);
+        m_pending.push_back(p);
+    }
+    m_cv_pending.notify_one();
+}
+
+std::shared_ptr<std::vector<uint8_t>> RansEncoderLibMultiThread::get_encoded_stream()
+{
+    std::unique_lock<std::mutex> lk(m_mutex_result);
+    m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; });
+    return RansEncoderLib::get_encoded_stream();
+}
+
+void RansEncoderLibMultiThread::reset()
+{
+    RansEncoderLib::reset();
+    std::lock_guard<std::mutex> lk(m_mutex_result);
+    m_result_ready = false;
+}
+
+void RansEncoderLibMultiThread::worker()
+{
+    while (!m_finish) {
+        std::unique_lock<std::mutex> lk(m_mutex_pending);
+        m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; });
+        if (m_finish) {
+            lk.unlock();
+            break;
+        }
+        if (m_pending.size() == 0) {
+            lk.unlock();
+            // std::cout << "continue in worker" << std::endl;
+            continue;
+        }
+        while (m_pending.size() > 0) {
+            auto p = m_pending.front();
+            m_pending.pop_front();
+            lk.unlock();
+            if (p.workType == WorkType::Flush) {
+                RansEncoderLib::flush();
+                {
+                    std::lock_guard<std::mutex> lk_result(m_mutex_result);
+                    m_result_ready = true;
+                }
+                m_cv_result.notify_one();
+            }
+            lk.lock();
+        }
+        lk.unlock();
+    }
+}
+
+void RansDecoderLib::set_stream(const std::shared_ptr<std::vector<uint8_t>> encoded)
+{
+    _stream = encoded;
+    _ptr8 = (uint8_t*)(_stream->data());
+    RansDecInit(_rans, _ptr8);
+}
+
+int RansDecoderLib::add_cdf(const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
+                            const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
+                            const std::shared_ptr<std::vector<int32_t>> offsets)
+{
+    _cdfs.push_back(cdfs);
+    _cdfs_sizes.push_back(cdfs_sizes);
+    _offsets.push_back(offsets);
+    return static_cast<int>(_cdfs.size()) - 1;
+}
+
+void RansDecoderLib::empty_cdf_buffer()
+{
+    _cdfs.clear();
+    _cdfs_sizes.clear();
+    _offsets.clear();
+}
+
+FORCE_INLINE int8_t RansDecoderLib::decode_one_symbol(const int32_t* cdf, const int32_t cdf_size,
+                                                      const int32_t offset)
+{
+    const int32_t max_value = cdf_size - 2;
+    const int32_t cum_freq = static_cast<int32_t>(RansDecGet(_rans));
+
+    int s = 1;
+    while (cdf[s++] <= cum_freq) {
+    }
+    s -= 2;
+
+    RansDecAdvance(_rans, _ptr8, cdf[s], cdf[s + 1] - cdf[s]);
+
+    int32_t value = static_cast<int32_t>(s);
+
+    if (value == max_value) {
+        /* Bypass decoding mode */
+        int32_t val = RansDecGetBits(_rans, _ptr8);
+        int32_t n_bypass = val;
+
+        while (val == max_bypass_val) {
+            val = RansDecGetBits(_rans, _ptr8);
+            n_bypass += val;
+        }
+
+        int32_t raw_val = 0;
+        for (int j = 0; j < n_bypass; ++j) {
+            val = RansDecGetBits(_rans, _ptr8);
+            raw_val |= val << (j * bypass_precision);
+        }
+        value = raw_val >> 1;
+        if (raw_val & 1) {
+            value = -value - 1;
+        } else {
+            value += max_value;
+        }
+    }
+
+    return static_cast<int8_t>(value + offset);
+}
+
+void RansDecoderLib::decode_y(const std::shared_ptr<std::vector<uint8_t>> indexes,
+                              const int cdf_group_index)
+{
+    int index_size = static_cast<int>(indexes->size());
+    m_decoded = std::make_shared<std::vector<int8_t>>(index_size);
+
+    int8_t* output_ptr = m_decoded->data();
+    const uint8_t* indexes_ptr = indexes->data();
+    const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+    const int32_t* offsets_ptr = _offsets[cdf_group_index]->data();
+    const auto& cdfs = _cdfs[cdf_group_index];
+    for (int i = 0; i < index_size; ++i) {
+        const int32_t cdf_idx = indexes_ptr[i];
+        output_ptr[i] = decode_one_symbol(cdfs->at(cdf_idx).data(), cdfs_sizes_ptr[cdf_idx],
+                                          offsets_ptr[cdf_idx]);
+    }
+}
+
+void RansDecoderLib::decode_z(const int total_size, const int cdf_group_index,
+                              const int start_offset, const int per_channel_size)
+{
+    m_decoded = std::make_shared<std::vector<int8_t>>(total_size);
+
+    int8_t* output_ptr = m_decoded->data();
+    const int32_t* cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+    const int32_t* offsets_ptr = _offsets[cdf_group_index]->data();
+    const auto& cdfs = _cdfs[cdf_group_index];
+    for (int i = 0; i < total_size; ++i) {
+        const int32_t cdf_idx = i / per_channel_size + start_offset;
+        output_ptr[i] = decode_one_symbol(cdfs->at(cdf_idx).data(), cdfs_sizes_ptr[cdf_idx],
+                                          offsets_ptr[cdf_idx]);
+    }
+}
+
+std::shared_ptr<std::vector<int8_t>> RansDecoderLib::get_decoded_tensor()
+{
+    return m_decoded;
+}
+
+RansDecoderLibMultiThread::RansDecoderLibMultiThread()
+    : RansDecoderLib()
+    , m_finish(false)
+    , m_result_ready(false)
+    , m_thread(std::thread(&RansDecoderLibMultiThread::worker, this))
+{
+}
+
+RansDecoderLibMultiThread::~RansDecoderLibMultiThread()
+{
+    {
+        std::lock_guard<std::mutex> lk(m_mutex_pending);
+        std::lock_guard<std::mutex> lk1(m_mutex_result);
+        m_finish = true;
+    }
+    m_cv_pending.notify_one();
+    m_cv_result.notify_one();
+    m_thread.join();
+}
+
+void RansDecoderLibMultiThread::decode_y(const std::shared_ptr<std::vector<uint8_t>> indexes,
+                                         const int cdf_group_index)
+{
+    {
+        std::lock_guard<std::mutex> lk(m_mutex_result);
+        m_result_ready = false;
+    }
+    PendingTask p;
+    p.workType = WorkType::EncodeDecodeY;
+    p.indexes = indexes;
+    p.cdf_group_index = cdf_group_index;
+    {
+        std::unique_lock<std::mutex> lk(m_mutex_pending);
+        m_pending.push_back(p);
+    }
+    m_cv_pending.notify_one();
+}
+
+void RansDecoderLibMultiThread::decode_z(const int total_size, const int cdf_group_index,
+                                         const int start_offset, const int per_channel_size)
+{
+    {
+        std::lock_guard<std::mutex> lk(m_mutex_result);
+        m_result_ready = false;
+    }
+    PendingTask p;
+    p.workType = WorkType::EncodeDecodeZ;
+    p.total_size = total_size;
+    p.cdf_group_index = cdf_group_index;
+    p.start_offset = start_offset;
+    p.per_channel_size = per_channel_size;
+    {
+        std::unique_lock<std::mutex> lk(m_mutex_pending);
+        m_pending.push_back(p);
+    }
+    m_cv_pending.notify_one();
+}
+
+std::shared_ptr<std::vector<int8_t>> RansDecoderLibMultiThread::get_decoded_tensor()
+{
+    std::unique_lock<std::mutex> lk(m_mutex_result);
+    m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; });
+    return RansDecoderLib::get_decoded_tensor();
+}
+
+void RansDecoderLibMultiThread::worker()
+{
+    while (!m_finish) {
+        std::unique_lock<std::mutex> lk(m_mutex_pending);
+        m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; });
+        if (m_finish) {
+            lk.unlock();
+            break;
+        }
+        if (m_pending.size() == 0) {
+            lk.unlock();
+            // std::cout << "continue in worker" << std::endl;
+            continue;
+        }
+        while (m_pending.size() > 0) {
+            auto p = m_pending.front();
+            m_pending.pop_front();
+            lk.unlock();
+            if (p.workType == WorkType::EncodeDecodeY) {
+                RansDecoderLib::decode_y(p.indexes, p.cdf_group_index);
+            } else if (p.workType == WorkType::EncodeDecodeZ) {
+                RansDecoderLib::decode_z(p.total_size, p.cdf_group_index, p.start_offset,
+                                         p.per_channel_size);
+            }
+            {
+                std::lock_guard<std::mutex> lk_result(m_mutex_result);
+                m_result_ready = true;
+            }
+            m_cv_result.notify_one();
+            lk.lock();
+        }
+        lk.unlock();
+    }
+}
diff --git a/DCVC-RT/src/cpp/py_rans/rans.h b/DCVC-RT/src/cpp/py_rans/rans.h
new file mode 100644
index 0000000..995dfee
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/rans.h
@@ -0,0 +1,201 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <condition_variable>
+#include <list>
+#include <memory>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#ifdef __GNUC__
+    #pragma GCC diagnostic push
+    #pragma GCC diagnostic ignored "-Wpedantic"
+    #pragma GCC diagnostic ignored "-Wsign-compare"
+#endif
+
+#ifdef _MSC_VER
+    #pragma warning(disable : 4244)
+#endif
+
+#include "rans_byte.h"
+
+#ifdef _MSC_VER
+    #pragma warning(default : 4244)
+#endif
+
+#ifdef __GNUC__
+    #pragma GCC diagnostic pop
+#endif
+
+#ifdef _MSC_VER
+    #define FORCE_INLINE __forceinline
+#endif
+
+#ifdef __GNUC__
+    #define FORCE_INLINE __attribute__((always_inline)) inline
+#endif
+
+struct RansSymbol {
+    uint16_t start;
+    uint16_t range;  // range for normal coding and 0 for bypass coding
+};
+
+enum class WorkType {
+    EncodeDecodeY,
+    EncodeDecodeZ,
+    Flush,
+};
+
+struct PendingTask {
+    WorkType workType;
+    std::shared_ptr<std::vector<int16_t>> symbols_y;
+    std::shared_ptr<std::vector<int8_t>> symbols_z;
+    std::shared_ptr<std::vector<uint8_t>> indexes;
+    int total_size{ 0 };
+    int cdf_group_index{ 0 };
+    int start_offset{ 0 };
+    int per_channel_size{ 0 };
+};
+
+/* NOTE: Warning, we buffer everything for now... In case of large files we
+ * should split the bitstream into chunks... Or for a memory-bounded encoder
+ */
+class RansEncoderLib {
+public:
+    RansEncoderLib();
+    virtual ~RansEncoderLib() = default;
+
+    RansEncoderLib(const RansEncoderLib&) = delete;
+    RansEncoderLib(RansEncoderLib&&) = delete;
+    RansEncoderLib& operator=(const RansEncoderLib&) = delete;
+    RansEncoderLib& operator=(RansEncoderLib&&) = delete;
+
+    void encode_y(const std::shared_ptr<std::vector<int16_t>> symbols, const int cdf_group_index);
+    void encode_z(const std::shared_ptr<std::vector<int8_t>> symbols, const int cdf_group_index,
+                  const int start_offset, const int per_channel_size);
+
+    FORCE_INLINE void encode_y_internal(uint8_t*& ptr, RansState& rans,
+                                        const std::shared_ptr<std::vector<int16_t>> symbols,
+                                        const int cdf_group_index);
+    FORCE_INLINE void encode_z_internal(uint8_t*& ptr, RansState& rans,
+                                        const std::shared_ptr<std::vector<int8_t>> symbols,
+                                        const int cdf_group_index, const int start_offset,
+                                        const int per_channel_size);
+    FORCE_INLINE void encode_one_symbol(uint8_t*& ptr, RansState& rans, const int32_t symbol,
+                                        const int32_t cdf_size, const int32_t offset,
+                                        const std::vector<RansSymbol>& ransSymbols);
+    virtual void flush();
+    virtual std::shared_ptr<std::vector<uint8_t>> get_encoded_stream();
+    virtual void reset();
+    virtual int add_cdf(const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
+                        const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
+                        const std::shared_ptr<std::vector<int32_t>> offsets);
+    virtual void empty_cdf_buffer();
+
+private:
+    std::shared_ptr<std::vector<uint8_t>> _stream;
+
+    std::vector<std::shared_ptr<std::vector<std::vector<RansSymbol>>>> _ransSymbols;
+    std::vector<std::shared_ptr<std::vector<int32_t>>> _cdfs_sizes;
+    std::vector<std::shared_ptr<std::vector<int32_t>>> _offsets;
+
+    std::list<PendingTask> m_pendingEncodingList;
+};
+
+class RansEncoderLibMultiThread : public RansEncoderLib {
+public:
+    RansEncoderLibMultiThread();
+    virtual ~RansEncoderLibMultiThread();
+    virtual void flush() override;
+    virtual std::shared_ptr<std::vector<uint8_t>> get_encoded_stream() override;
+    virtual void reset() override;
+
+    void worker();
+
+private:
+    bool m_finish;
+    bool m_result_ready;
+    std::thread m_thread;
+    std::mutex m_mutex_result;
+    std::mutex m_mutex_pending;
+    std::condition_variable m_cv_pending;
+    std::condition_variable m_cv_result;
+    std::list<PendingTask> m_pending;
+};
+
+class RansDecoderLib {
+public:
+    RansDecoderLib() {}
+    virtual ~RansDecoderLib() = default;
+
+    RansDecoderLib(const RansDecoderLib&) = delete;
+    RansDecoderLib(RansDecoderLib&&) = delete;
+    RansDecoderLib& operator=(const RansDecoderLib&) = delete;
+    RansDecoderLib& operator=(RansDecoderLib&&) = delete;
+
+    virtual void set_stream(const std::shared_ptr<std::vector<uint8_t>> encoded);
+
+    FORCE_INLINE int8_t decode_one_symbol(const int32_t* cdf, const int32_t cdf_size,
+                                          const int32_t offset);
+
+    virtual void decode_y(const std::shared_ptr<std::vector<uint8_t>> indexes,
+                          const int cdf_group_index);
+    virtual void decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+                          const int per_channel_size);
+
+    virtual std::shared_ptr<std::vector<int8_t>> get_decoded_tensor();
+
+    virtual int add_cdf(const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
+                        const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
+                        const std::shared_ptr<std::vector<int32_t>> offsets);
+    virtual void empty_cdf_buffer();
+
+private:
+    RansState _rans;
+    uint8_t* _ptr8;
+    std::shared_ptr<std::vector<uint8_t>> _stream;
+    std::shared_ptr<std::vector<int8_t>> m_decoded;
+
+    std::vector<std::shared_ptr<std::vector<std::vector<int32_t>>>> _cdfs;
+    std::vector<std::shared_ptr<std::vector<int32_t>>> _cdfs_sizes;
+    std::vector<std::shared_ptr<std::vector<int32_t>>> _offsets;
+};
+
+class RansDecoderLibMultiThread : public RansDecoderLib {
+public:
+    RansDecoderLibMultiThread();
+    virtual ~RansDecoderLibMultiThread();
+
+    virtual void decode_y(const std::shared_ptr<std::vector<uint8_t>> indexes,
+                          const int cdf_group_index) override;
+
+    virtual void decode_z(const int total_size, const int cdf_group_index, const int start_offset,
+                          const int per_channel_size) override;
+
+    virtual std::shared_ptr<std::vector<int8_t>> get_decoded_tensor() override;
+
+    void worker();
+
+private:
+    bool m_finish;
+    bool m_result_ready;
+    std::thread m_thread;
+    std::mutex m_mutex_result;
+    std::mutex m_mutex_pending;
+    std::condition_variable m_cv_pending;
+    std::condition_variable m_cv_result;
+    std::list<PendingTask> m_pending;
+};
\ No newline at end of file
diff --git a/DCVC-RT/src/cpp/py_rans/rans_byte.h b/DCVC-RT/src/cpp/py_rans/rans_byte.h
new file mode 100644
index 0000000..9c77e31
--- /dev/null
+++ b/DCVC-RT/src/cpp/py_rans/rans_byte.h
@@ -0,0 +1,141 @@
+// The code is from https://github.com/rygorous/ryg_rans
+// The original license is below.
+
+// To the extent possible under law, Fabian Giesen has waived all
+// copyright and related or neighboring rights to ryg_rans, as
+// per the terms of the CC0 license:
+
+// https://creativecommons.org/publicdomain/zero/1.0
+
+// This work is published from the United States.
+
+// Simple byte-aligned rANS encoder/decoder - public domain - Fabian 'ryg'
+// Giesen 2014
+//
+// Not intended to be "industrial strength"; just meant to illustrate the
+// general idea.
+
+#pragma once
+
+#include <stdint.h>
+
+#ifdef assert
+    #define RansAssert assert
+#else
+    #define RansAssert(x)
+#endif
+
+// READ ME FIRST:
+//
+// This is designed like a typical arithmetic coder API, but there's three
+// twists you absolutely should be aware of before you start hacking:
+//
+// 1. You need to encode data in *reverse* - last symbol first. rANS works
+//    like a stack: last in, first out.
+// 2. Likewise, the encoder outputs bytes *in reverse* - that is, you give
+//    it a pointer to the *end* of your buffer (exclusive), and it will
+//    slowly move towards the beginning as more bytes are emitted.
+// 3. Unlike basically any other entropy coder implementation you might
+//    have used, you can interleave data from multiple independent rANS
+//    encoders into the same bytestream without any extra signaling;
+//    you can also just write some bytes by yourself in the middle if
+//    you want to. This is in addition to the usual arithmetic encoder
+//    property of being able to switch models on the fly. Writing raw
+//    bytes can be useful when you have some data that you know is
+//    incompressible, and is cheaper than going through the rANS encode
+//    function. Using multiple rANS coders on the same byte stream wastes
+//    a few bytes compared to using just one, but execution of two
+//    independent encoders can happen in parallel on superscalar and
+//    Out-of-Order CPUs, so this can be *much* faster in tight decoding
+//    loops.
+//
+//    This is why all the rANS functions take the write pointer as an
+//    argument instead of just storing it in some context struct.
+
+// --------------------------------------------------------------------------
+
+// L ('l' in the paper) is the lower bound of our normalization interval.
+// Between this and our byte-aligned emission, we use 31 (not 32!) bits.
+// This is done intentionally because exact reciprocals for 31-bit uints
+// fit in 32-bit uints: this permits some optimizations during encoding.
+constexpr int SCALE_BITS = 16;
+constexpr int RANS_SHIFT_BITS = 23;
+constexpr uint32_t RANS_BYTE_L = (1u << RANS_SHIFT_BITS);
+constexpr int ENC_RENORM_SHIFT_BITS = RANS_SHIFT_BITS - SCALE_BITS + 8;
+constexpr uint32_t DEC_MASK = ((1u << SCALE_BITS) - 1);
+
+// State for a rANS encoder. Yep, that's all there is to it.
+typedef uint32_t RansState;
+
+// Initialize a rANS encoder.
+static inline void RansEncInit(RansState& r)
+{
+    r = RANS_BYTE_L;
+}
+
+// Renormalize the encoder. Internal function.
+static inline void RansEncRenorm(RansState& x, uint8_t*& ptr, uint32_t freq)
+{
+    const uint32_t x_max = freq << ENC_RENORM_SHIFT_BITS;
+    while (x >= x_max) {
+        *(--ptr) = static_cast<uint8_t>(x & 0xff);
+        x >>= 8;
+    }
+}
+
+// Encodes a single symbol with range start "start" and frequency "freq".
+// All frequencies are assumed to sum to "1 << scale_bits", and the
+// resulting bytes get written to ptr (which is updated).
+//
+// NOTE: With rANS, you need to encode symbols in *reverse order*, i.e.
+// last symbol first! Likewise, the output bytestream is written *backwards*:
+// ptr starts pointing at the end of the output buffer and keeps decrementing.
+static inline void RansEncPut(RansState& r, uint8_t*& ptr, uint32_t start, uint32_t freq) +{ + // renormalize + RansEncRenorm(r, ptr, freq); + + // x = C(s,x) + r = ((r / freq) << SCALE_BITS) + (r % freq) + start; +} + +// Flushes the rANS encoder. +static inline void RansEncFlush(const RansState& r, uint8_t*& ptr) +{ + ptr -= 4; + ptr[0] = (uint8_t)(r >> 0); + ptr[1] = (uint8_t)(r >> 8); + ptr[2] = (uint8_t)(r >> 16); + ptr[3] = (uint8_t)(r >> 24); +} + +// Initializes a rANS decoder. +// Unlike the encoder, the decoder works forwards as you'd expect. +static inline void RansDecInit(RansState& r, uint8_t*& ptr) +{ + r = (*ptr++) << 0; + r |= (*ptr++) << 8; + r |= (*ptr++) << 16; + r |= (*ptr++) << 24; +} + +// Returns the current cumulative frequency (map it to a symbol yourself!) +static inline uint32_t RansDecGet(RansState& r) +{ + return r & DEC_MASK; +} + +// Advances in the bit stream by "popping" a single symbol with range start +// "start" and frequency "freq". All frequencies are assumed to sum to "1 << +// scale_bits", and the resulting bytes get written to ptr (which is updated). +static inline void RansDecAdvance(RansState& r, uint8_t*& ptr, uint32_t start, uint32_t freq) +{ + + // s, x = D(x) + r = freq * (r >> SCALE_BITS) + (r & DEC_MASK) - start; + + // renormalize + while (r < RANS_BYTE_L) { + r = (r << 8) | *ptr++; + } +} diff --git a/DCVC-RT/src/cpp/setup.py b/DCVC-RT/src/cpp/setup.py new file mode 100644 index 0000000..7e1b036 --- /dev/null +++ b/DCVC-RT/src/cpp/setup.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import glob +import sys +from setuptools import setup +from pybind11.setup_helpers import Pybind11Extension, build_ext + + +if sys.platform == "win32": + extra_compile_args = ['/std:c++17', '/O2', '/W4', '/WX', '/wd4100'] + extra_link_args = [] +else: + extra_compile_args = ['-std=c++17', '-O3', '-fPIC', '-Wall', '-Wextra', '-Werror'] + extra_link_args = [] + + +setup( + name="MLCodec_extensions_cpp", + ext_modules=[ + Pybind11Extension( + name='MLCodec_extensions_cpp', + sources=glob.glob('py_rans/*.cpp'), + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ), + ], + cmdclass={"build_ext": build_ext}, + zip_safe=False, + python_requires=">=3.12", +) diff --git a/DCVC-RT/src/layers/cuda_inference.py b/DCVC-RT/src/layers/cuda_inference.py new file mode 100644 index 0000000..dee9836 --- /dev/null +++ b/DCVC-RT/src/layers/cuda_inference.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os + +import torch +import torch.nn.functional as F + + +CUSTOMIZED_CUDA_INFERENCE = False +try: + from inference_extensions_cuda import process_with_mask_cuda, combine_for_reading_2x_cuda, \ + restore_y_2x_cuda, restore_y_4x_cuda, build_index_dec_cuda, \ + round_and_to_int8_cuda, clamp_reciprocal_with_quant_cuda, bias_quant_cuda, \ + add_and_multiply_cuda, bias_pixel_shuffle_8_cuda, replicate_pad_cuda, \ + build_index_enc_cuda, DepthConvProxy, SubpelConv2xProxy # noqa: F401 + CUSTOMIZED_CUDA_INFERENCE = True +except Exception: # pylint: disable=W0718 + pass + + +if not CUSTOMIZED_CUDA_INFERENCE and 'SUPPRESS_CUSTOM_KERNEL_WARNING' not in os.environ: + print("cannot import cuda implementation for inference, fallback to pytorch.") + + +def round_and_to_int8(z): + if CUSTOMIZED_CUDA_INFERENCE and z.is_cuda: + z_int8 = round_and_to_int8_cuda(z) + return z, z_int8 + + z_hat = torch.clamp(torch.round(z), -128., 127.) 
+ z_hat_write = z_hat.to(dtype=torch.int8) + return z_hat, z_hat_write + + +def clamp_reciprocal_with_quant(q_dec, y, min_val): + if CUSTOMIZED_CUDA_INFERENCE and q_dec.is_cuda: + # q_dec is not inplace modified at decoder side + q_dec = clamp_reciprocal_with_quant_cuda(q_dec, y, min_val) + return q_dec, y + + q_dec = torch.clamp_min(q_dec, min_val) + q_enc = torch.reciprocal(q_dec) + y = y * q_enc + return q_dec, y + + +def add_and_multiply(y_hat_0, y_hat_1, q_dec): + if CUSTOMIZED_CUDA_INFERENCE and y_hat_0.is_cuda: + add_and_multiply_cuda(y_hat_0, y_hat_1, q_dec) + return y_hat_0 + + y_hat = y_hat_0 + y_hat_1 + y_hat = y_hat * q_dec + return y_hat + + +def process_with_mask(y, scales, means, mask, force_zero_thres): + if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda: + thres = force_zero_thres if force_zero_thres is not None else -1. + return process_with_mask_cuda(y, scales, means, mask, thres) + + scales_hat = scales * mask + means_hat = means * mask + + y_res = (y - means_hat) * mask + y_q = torch.round(y_res) + if force_zero_thres is not None: + cond = scales_hat > force_zero_thres + y_q = y_q * cond + y_q = torch.clamp(y_q, -128., 127.) 
+ y_hat = y_q + means_hat + + return y_res, y_q, y_hat, scales_hat + + +def combine_for_reading_2x(x, mask, inplace=False): + if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda and x.is_contiguous(): + B, C, H, W = x.shape + if inplace: + out = x[:, :C // 2, :, :] + else: + out = torch.empty((B, C // 2, H, W), dtype=x.dtype, layout=x.layout, device=x.device) + combine_for_reading_2x_cuda(out, x, mask) + return out + + x = x * mask + x0, x1 = x.chunk(2, 1) + return x0 + x1 + + +def restore_y_2x(y, means, mask): + if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda and y.is_contiguous(): + out = torch.empty_like(means) + restore_y_2x_cuda(out, y, means, mask) + return out + + return (torch.cat((y, y), dim=1) + means) * mask + + +def restore_y_2x_with_cat_after(y, means, mask, to_cat): + if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda and y.is_contiguous(): + B, C1, H, W = means.shape + C2 = to_cat.shape[1] + out = torch.empty((B, C1 + C2, H, W), dtype=means.dtype, layout=means.layout, + device=means.device) + restore_y_2x_cuda(out[:, :C1, :, :], y, means, mask) + out[:, C1:, :, :] = to_cat + return out[:, :C1, :, :], out + + out = (torch.cat((y, y), dim=1) + means) * mask + return out, torch.cat((out, to_cat), dim=1) + + +def restore_y_4x(y, means, mask): + if CUSTOMIZED_CUDA_INFERENCE and y.is_cuda and y.is_contiguous(): + out = torch.empty_like(means) + restore_y_4x_cuda(out, y, means, mask) + return out + + return (torch.cat((y, y, y, y), dim=1) + means) * mask + + +def build_index_dec(scales, scale_min, scale_max, log_scale_min, log_step_recip, skip_thres=None): + if CUSTOMIZED_CUDA_INFERENCE and scales.is_cuda: + out = torch.empty_like(scales, dtype=torch.uint8) + skip_cond = None + if skip_thres is not None: + skip_cond = torch.empty_like(scales, dtype=torch.bool) + else: + skip_thres = -1. 
+ + build_index_dec_cuda(out, skip_cond, scales, scale_min, scale_max, log_scale_min, + log_step_recip, skip_thres) + return out, skip_cond + + skip_cond = None + if skip_thres is not None: + skip_cond = scales > skip_thres + scales = scales.clamp_(scale_min, scale_max) + indexes = (torch.log(scales) - log_scale_min) * log_step_recip + indexes = indexes.to(dtype=torch.uint8) + return indexes, skip_cond + + +def build_index_enc(symbols, scales, scale_min, scale_max, log_scale_min, + log_step_recip, skip_thres=None): + if CUSTOMIZED_CUDA_INFERENCE and scales.is_cuda: + out = torch.empty_like(scales, dtype=torch.int16) + skip_cond = None + if skip_thres is not None: + skip_cond = torch.empty_like(scales, dtype=torch.bool) + else: + skip_thres = -1. + + build_index_enc_cuda(out, skip_cond, symbols, scales, scale_min, scale_max, log_scale_min, + log_step_recip, skip_thres) + + out = out[skip_cond] + return out + + scales = scales.clamp_(scale_min, scale_max) + indexes = (torch.log(scales) - log_scale_min) * log_step_recip + indexes = indexes.to(dtype=torch.uint8) + symbols = symbols.to(dtype=torch.int16) + out = (symbols << 8) + indexes + out = out.to(dtype=torch.int16) + if skip_thres is not None: + skip_cond = scales > skip_thres + out = out[skip_cond] + return out + + +def replicate_pad(x, pad_b, pad_r): + if pad_b == 0 and pad_r == 0: + return x + if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda: + return replicate_pad_cuda(x, pad_b, pad_r) + return F.pad(x, (0, pad_r, 0, pad_b), mode="replicate") + + +def bias_pixel_shuffle_8(x, bias): + if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda: + B, C, H, W = x.shape + assert B == 1 + out = torch.empty((B, 3, H * 8, W * 8), dtype=x.dtype, device=x.device, layout=x.layout) + bias_pixel_shuffle_8_cuda(out, x, bias, C, H * W, W, True) + return out + + out = x + bias[None, :, None, None] + out = F.pixel_shuffle(out, 8) + out = torch.clamp(out, 0., 1.) 
+ return out + + +def bias_quant(x, bias, quant_step): + if CUSTOMIZED_CUDA_INFERENCE and x.is_cuda: + bias_quant_cuda(x, bias, quant_step) + return x + + out = x + bias[None, :, None, None] + out = out * quant_step + return out diff --git a/DCVC-RT/src/layers/extensions/inference/bind.cpp b/DCVC-RT/src/layers/extensions/inference/bind.cpp new file mode 100644 index 0000000..0bc8e24 --- /dev/null +++ b/DCVC-RT/src/layers/extensions/inference/bind.cpp @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "def.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("process_with_mask_cuda", &process_with_mask_cuda); + m.def("combine_for_reading_2x_cuda", &combine_for_reading_2x_cuda); + m.def("restore_y_2x_cuda", &restore_y_2x_cuda); + m.def("restore_y_4x_cuda", &restore_y_4x_cuda); + m.def("build_index_dec_cuda", &build_index_dec_cuda); + m.def("build_index_enc_cuda", &build_index_enc_cuda); + m.def("bias_quant_cuda", &bias_quant_cuda); + m.def("round_and_to_int8_cuda", &round_and_to_int8_cuda); + m.def("clamp_reciprocal_with_quant_cuda", &clamp_reciprocal_with_quant_cuda); + m.def("add_and_multiply_cuda", &add_and_multiply_cuda); + m.def("bias_pixel_shuffle_8_cuda", &bias_pixel_shuffle_8_cuda); + m.def("replicate_pad_cuda", &replicate_pad_cuda); + m.def("bias_wsilu_depthwise_conv2d_cuda", &bias_wsilu_depthwise_conv2d_cuda); + + py::class_(m, "DepthConvProxy") + .def(py::init<>()) + .def("set_param", &DepthConvProxy::set_param) + .def("set_param_with_adaptor", &DepthConvProxy::set_param_with_adaptor) + .def("forward", &DepthConvProxy::forward) + .def("forward_with_quant_step", &DepthConvProxy::forward_with_quant_step) + .def("forward_with_cat", &DepthConvProxy::forward_with_cat); + + py::class_(m, "SubpelConv2xProxy") + .def(py::init<>()) + .def("set_param", &SubpelConv2xProxy::set_param) + .def("forward", &SubpelConv2xProxy::forward) + .def("forward_with_cat", &SubpelConv2xProxy::forward_with_cat); +} 
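The Python helpers in `cuda_inference.py` above all follow one dispatch pattern: use the fused CUDA kernel when the extension imports and the tensor is on GPU, otherwise fall back to plain PyTorch ops. As an illustration of what the fallback path of `process_with_mask` computes, here is a NumPy sketch (`process_with_mask_ref` is a hypothetical name for this note, not part of the codebase; it mirrors the PyTorch branch line by line):

```python
import numpy as np

def process_with_mask_ref(y, scales, means, mask, force_zero_thres=None):
    """NumPy mirror of the PyTorch fallback: mask the entropy parameters,
    quantize the masked residual, optionally force symbols whose scale is
    below the threshold to zero, and reconstruct y_hat."""
    scales_hat = scales * mask
    means_hat = means * mask
    y_res = (y - means_hat) * mask          # residual to be coded
    y_q = np.round(y_res)                   # quantized symbols
    if force_zero_thres is not None:
        y_q = y_q * (scales_hat > force_zero_thres)  # skip near-zero-scale symbols
    y_q = np.clip(y_q, -128., 127.)         # keep symbols in int8 range
    y_hat = y_q + means_hat                 # dequantized reconstruction
    return y_res, y_q, y_hat, scales_hat
```

For example, with `y = [1.7, -0.2]`, `means = [1.0, 0.0]`, `scales = [0.5, 0.01]`, a threshold of `0.1` zeroes the second symbol (its scale is too small to be worth coding) while the first quantizes to `1.0`.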
diff --git a/DCVC-RT/src/layers/extensions/inference/common.h b/DCVC-RT/src/layers/extensions/inference/common.h
new file mode 100644
index 0000000..7a8bfed
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/common.h
@@ -0,0 +1,319 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+#include <cassert>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+
+// T may be a vector type, and may differ from t.dtype
+template <typename T>
+struct GPUTensor1D {
+    GPUTensor1D(torch::Tensor& t) : ptr(static_cast<T*>(t.data_ptr())) {}
+    GPUTensor1D(const torch::Tensor& t) : ptr(static_cast<T*>(t.data_ptr())) {}
+    GPUTensor1D(T* t) : ptr(static_cast<T*>(t)) { assert(t == nullptr); }
+
+    __device__ T& operator[](int idx) { return ptr[idx]; }
+    __device__ T& operator[](int idx) const { return ptr[idx]; }
+
+    T* __restrict__ const ptr;
+};
+
+template <typename T>
+using Packed4DTensorAccessor32 = torch::PackedTensorAccessor32<T, 4, torch::RestrictPtrTraits>;
+template <typename T>
+using Packed1DTensorAccessor32 = torch::PackedTensorAccessor32<T, 1, torch::RestrictPtrTraits>;
+
+struct __align__(8) Half4
+{
+    c10::Half x;
+    c10::Half y;
+    c10::Half z;
+    c10::Half w;
+};
+
+struct __align__(4) bool4
+{
+    bool x;
+    bool y;
+    bool z;
+    bool w;
+};
+
+__forceinline__ __device__ float4 make_vec4(const float& x, const float& y, const float& z,
+                                            const float& w)
+{
+    return make_float4(x, y, z, w);
+}
+
+__forceinline__ __device__ Half4 make_vec4(const c10::Half& x, const c10::Half& y,
+                                           const c10::Half& z, const c10::Half& w)
+{
+    Half4 t;
+    t.x = x;
+    t.y = y;
+    t.z = z;
+    t.w = w;
+    return t;
+}
+
+__forceinline__ __device__ Half4 make_Half4(const c10::Half& x, const c10::Half& y,
+                                            const c10::Half& z, const c10::Half& w)
+{
+    Half4 t;
+    t.x = x;
+    t.y = y;
+    t.z = z;
+    t.w = w;
+    return t;
+}
+
+__forceinline__ __device__ bool4 make_vec4(const bool& x, const bool& y, const bool& z, const bool& w)
+{
+    bool4 t;
+    t.x = x;
+    t.y = y;
+    t.z = z;
+    t.w = w;
+    return t;
+}
+
+__forceinline__ __device__ c10::Half round(const c10::Half& a)
+{
+    return static_cast<c10::Half>(__half2int_rn(a));
+}
+
+template <typename T>
+__forceinline__ __device__ T round(const T& a)
+{
+    return make_vec4(round(a.x), round(a.y), round(a.z), round(a.w));
+}
+
+__forceinline__ __device__ int8_t to_int8(const float& a)
+{
+    return static_cast<int8_t>(a);
+}
+
+__forceinline__ __device__ int8_t to_int8(const c10::Half& a)
+{
+    return static_cast<int8_t>(a);
+}
+
+template <typename T>
+__forceinline__ __device__ char4 to_int8(const T& a)
+{
+    return make_char4(to_int8(a.x), to_int8(a.y), to_int8(a.z), to_int8(a.w));
+}
+
+__forceinline__ __device__ uint8_t to_uint8(const float& a)
+{
+    return static_cast<uint8_t>(a);
+}
+
+__forceinline__ __device__ uint8_t to_uint8(const c10::Half& a)
+{
+    return static_cast<uint8_t>(__half2uint_rd(a));
+}
+
+template <typename T>
+__forceinline__ __device__ uchar4 to_uint8(const T& a)
+{
+    return make_uchar4(to_uint8(a.x), to_uint8(a.y), to_uint8(a.z), to_uint8(a.w));
+}
+
+__forceinline__ __device__ int16_t to_int16(const float& a)
+{
+    return static_cast<int16_t>(a);
+}
+
+__forceinline__ __device__ int16_t to_int16(const c10::Half& a)
+{
+    return static_cast<int16_t>(__half2int_rd(a));
+}
+
+template <typename T>
+__forceinline__ __device__ short4 to_int16(const T& a)
+{
+    return make_short4(to_int16(a.x), to_int16(a.y), to_int16(a.z), to_int16(a.w));
+}
+
+__forceinline__ __device__ short4 operator<<(const short4& a, const int b)
+{
+    return make_short4(a.x << b, a.y << b, a.z << b, a.w << b);
+}
+
+__forceinline__ __device__ c10::Half min(const c10::Half& a, const c10::Half& b)
+{
+    return __hmin(a, b);
+}
+__forceinline__ __device__ c10::Half max(const c10::Half& a, const c10::Half& b)
+{
+    return __hmax(a, b);
+}
+
+__forceinline__ __device__ bool operator>(const c10::Half& a, const c10::Half& b)
+{
+    return __hgt(a, b);
+}
+
+__forceinline__ __device__ bool operator<(const c10::Half& a, const c10::Half& b)
+{
+    return __hlt(a, b);
+}
+
+__forceinline__ __device__ c10::Half log(const c10::Half& a)
+{
+    return hlog(a);
+}
+
+__forceinline__ __device__ short4 operator+(const short4& a, const short4& b)
+{
+    return make_short4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+
+__forceinline__ __device__ float4 operator+(const float4& a, const float4& b)
+{
+    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+
+__forceinline__ __device__ Half4 operator+(const Half4& a, const Half4& b)
+{
+    return make_Half4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+
+template <typename T>
+__forceinline__ __device__ T operator-(const T& a, const T& b)
+{
+    return make_vec4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
+}
+
+__forceinline__ __device__ float4 operator-(const float4& a, const float& b)
+{
+    return make_vec4(a.x - b, a.y - b, a.z - b, a.w - b);
+}
+
+__forceinline__ __device__ Half4 operator-(const Half4& a, const c10::Half& b)
+{
+    return make_vec4(a.x - b, a.y - b, a.z - b, a.w - b);
+}
+
+__forceinline__ __device__ c10::Half operator*(const c10::Half& a, const bool b)
+{
+    return b ? a : static_cast<c10::Half>(0.f);
+}
+
+template <typename T>
+__forceinline__ __device__ T operator*(const T& a, const T& b)
+{
+    return make_vec4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+
+template <typename T>
+__forceinline__ __device__ T operator*(const T& a, const bool4& b)
+{
+    return make_vec4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+
+template <typename T1, typename T2>
+__forceinline__ __device__ T1 operator*(const T1& a, const T2& b)
+{
+    return make_vec4(a.x * b, a.y * b, a.z * b, a.w * b);
+}
+
+template <typename T>
+__forceinline__ __device__ T max(const T& a, const T& b)
+{
+    return make_vec4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w));
+}
+
+template <typename T1, typename T2>
+__forceinline__ __device__ T1 max(const T1& a, const T2& b)
+{
+    return make_vec4(max(a.x, b), max(a.y, b), max(a.z, b), max(a.w, b));
+}
+
+template <typename T>
+__forceinline__ __device__ T min(const T& a, const T& b)
+{
+    return make_vec4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w));
+}
+
+template <typename T1, typename T2>
+__forceinline__ __device__ T1 min(const T1& a, const T2& b)
+{
+    return make_vec4(min(a.x, b), min(a.y, b), min(a.z, b), min(a.w, b));
+}
+
+template <typename T>
+__forceinline__ __device__ T log(const T& a)
+{
+    return make_vec4(log(a.x), log(a.y), log(a.z), log(a.w));
+}
+
+__forceinline__ __device__ float reciprocal(const float& a)
+{
+    return __frcp_rd(a);
+}
+
+__forceinline__ __device__ c10::Half reciprocal(const c10::Half& a)
+{
+    return hrcp(a);
+}
+
+template <typename T>
+__forceinline__ __device__ T reciprocal(const T& a)
+{
+    return make_vec4(reciprocal(a.x), reciprocal(a.y), reciprocal(a.z), reciprocal(a.w));
+}
+
+template <typename T1, typename T2>
+__forceinline__ __device__ bool4 operator>(const T1& a, const T2& b)
+{
+    return make_vec4(a.x > b, a.y > b, a.z > b, a.w > b);
+}
+
+__forceinline__ __device__ float sigmoid(const float x)
+{
+    return 1.0f / (1.0f + expf(-x));
+}
+
+__forceinline__ __device__ float wsilu(const float x)
+{
+    return x * sigmoid(4.0f * x);
+}
+
+__forceinline__ __device__ c10::Half wsilu(const c10::Half x)
+{
+    return __float2half_rn(wsilu(__half2float(x)));
+}
+
+__forceinline__ __device__ float4 wsilu(float4 data)
+{
+    data.x = wsilu(data.x);
+    data.y = wsilu(data.y);
+    data.z = wsilu(data.z);
+    data.w = wsilu(data.w);
+    return data;
+}
+
+__forceinline__ __device__ Half4 wsilu(Half4 data)
+{
+    data.x = wsilu(data.x);
+    data.y = wsilu(data.y);
+    data.z = wsilu(data.z);
+    data.w = wsilu(data.w);
+    return data;
+}
+
+__forceinline__ __device__ float multiply_add(const float a, const float b, const float c)
+{
+    return __fmaf_rn(a, b, c);
+}
+
+__forceinline__ __device__ c10::Half multiply_add(const c10::Half a, const c10::Half b, const c10::Half c)
+{
+    return __hfma(a, b, c);
+}
\ No newline at end of file
diff --git a/DCVC-RT/src/layers/extensions/inference/def.h b/DCVC-RT/src/layers/extensions/inference/def.h
new file mode 100644
index 0000000..17e7782
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/def.h
@@ -0,0 +1,107 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include <torch/extension.h>
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+process_with_mask_cuda(const torch::Tensor& y, const torch::Tensor& scales, const torch::Tensor& means,
+                       const torch::Tensor& mask, const float force_zero_thres);
+
+void combine_for_reading_2x_cuda(torch::Tensor& out, const torch::Tensor& x, const torch::Tensor& mask);
+void restore_y_2x_cuda(torch::Tensor& out, const torch::Tensor& y, const torch::Tensor& means,
+                       const torch::Tensor& mask);
+void restore_y_4x_cuda(torch::Tensor& out, const torch::Tensor& y, const torch::Tensor& means,
+                       const torch::Tensor& mask);
+
+void build_index_dec_cuda(torch::Tensor& out, torch::optional<torch::Tensor>& cond_out,
+                          const torch::Tensor& scales, const float scale_min, const float scale_max,
+                          const float log_scale_min, const float log_step_recip,
+                          const float skip_thres);
+
+void build_index_enc_cuda(torch::Tensor& out, torch::optional<torch::Tensor>& cond_out,
+                          const torch::Tensor& symbols, const torch::Tensor& scales,
+                          const float scale_min, const float scale_max, const float log_scale_min,
+                          const float log_step_recip, const float skip_thres);
+
+void bias_wsilu_cuda(torch::Tensor& x, const torch::Tensor& bias);
+
+void bias_shortcut_cuda(torch::Tensor& x, const torch::Tensor& bias, const torch::Tensor& shortcut);
+void bias_shortcut_no_inplace_cuda(torch::Tensor& out, const torch::Tensor& x,
+                                   const torch::Tensor& bias, const torch::Tensor& shortcut);
+void bias_shortcut_2_cuda(torch::Tensor& x, const torch::Tensor& bias, torch::Tensor& shortcut);
+void bias_shortcut_with_quant_step_cuda(torch::Tensor& x, const torch::Tensor& bias,
+                                        const torch::Tensor& quant_step, const torch::Tensor& shortcut);
+
+void bias_quant_cuda(torch::Tensor& x, const torch::Tensor& bias, const torch::Tensor& quant_step);
+
+void bias_wsilu_chunk_add_cuda(torch::Tensor& x, const torch::Tensor& bias);
+
+void bias_pixel_shuffle_2_cuda(torch::Tensor& out, const torch::Tensor& x,
+                               const torch::Tensor& bias, const int C, const int N, const int W);
+void bias_pixel_shuffle_8_cuda(torch::Tensor& out, const torch::Tensor& x, const torch::Tensor& bias,
+                               const int C, const int N, const int W, bool clamp);
+torch::Tensor replicate_pad_cuda(const torch::Tensor& x, const int padB, const int padR);
+
+torch::Tensor round_and_to_int8_cuda(torch::Tensor& z);
+torch::Tensor clamp_reciprocal_with_quant_cuda(const torch::Tensor& q_dec, torch::Tensor& y,
+                                               const float min_val);
+void add_and_multiply_cuda(torch::Tensor& x0, const torch::Tensor& x1, const torch::Tensor q);
+
+torch::Tensor bias_wsilu_depthwise_conv2d_cuda(const torch::Tensor& x, const torch::Tensor& weight,
+                                               const torch::Tensor& bias);
+
+class DepthConvProxy {
+public:
+    DepthConvProxy() = default;
+    ~DepthConvProxy() = default;
+
+    void set_param(const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
+                   const torch::Tensor& dc_depth_conv_weight,
+                   const torch::Tensor& dc_depth_conv_bias, const torch::Tensor& dc_conv2_weight,
+                   const torch::Tensor& dc_conv2_bias, const torch::Tensor& ffn_conv1_weight,
+                   const torch::Tensor& ffn_conv1_bias, const torch::Tensor& ffn_conv2_weight,
+                   const torch::Tensor& ffn_conv2_bias, const bool shortcut);
+    void set_param_with_adaptor(
+        const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
+        const torch::Tensor& dc_depth_conv_weight, const torch::Tensor& dc_depth_conv_bias,
+        const torch::Tensor& dc_conv2_weight, const torch::Tensor& dc_conv2_bias,
+        const torch::Tensor& ffn_conv1_weight, const torch::Tensor& ffn_conv1_bias,
+        const torch::Tensor& ffn_conv2_weight, const torch::Tensor& ffn_conv2_bias,
+        const torch::Tensor& adaptor_weight, const torch::Tensor& adaptor_bias, const bool shortcut);
+    torch::Tensor forward(const torch::Tensor& x);
+    torch::Tensor forward_with_quant_step(const torch::Tensor& x, const torch::Tensor& quant_step);
+    torch::Tensor forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
+                                   const bool cat_at_front);
+
+private:
+    std::tuple<torch::Tensor, torch::Tensor> forward_common(const torch::Tensor& x);
+
+private:
+    torch::Tensor _dc_conv1_weight;
+    torch::Tensor _dc_conv1_bias;
+    torch::Tensor _dc_depth_conv_weight;
+    torch::Tensor _dc_conv2_weight;
+    torch::Tensor _dc_conv2_bias;
+    torch::Tensor _ffn_conv1_weight;
+    torch::Tensor _ffn_conv1_bias;
+    torch::Tensor _ffn_conv2_weight;
+    torch::Tensor _ffn_conv2_bias;
+    bool _adaptor{ false };
+    bool _shortcut{ false };
+};
+
+class SubpelConv2xProxy {
+public:
+    SubpelConv2xProxy() = default;
+    ~SubpelConv2xProxy() = default;
+
+    void set_param(const torch::Tensor& weight, const torch::Tensor& bias, const int padding);
+    torch::Tensor forward(const torch::Tensor& x);
+    torch::Tensor forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
+                                   const bool cat_at_front);
+
+private:
+    torch::Tensor _weight;
+    torch::Tensor _bias;
+    int _padding{ 0 };
+};
diff --git a/DCVC-RT/src/layers/extensions/inference/impl.cpp b/DCVC-RT/src/layers/extensions/inference/impl.cpp
new file mode 100644
index 0000000..6419fff
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/impl.cpp
@@ -0,0 +1,167 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "def.h"
+namespace F = torch::nn::functional;
+
+void DepthConvProxy::set_param(
+    const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
+    const torch::Tensor& dc_depth_conv_weight, const torch::Tensor& dc_depth_conv_bias,
+    const torch::Tensor& dc_conv2_weight, const torch::Tensor& dc_conv2_bias,
+    const torch::Tensor& ffn_conv1_weight, const torch::Tensor& ffn_conv1_bias,
+    const torch::Tensor& ffn_conv2_weight, const torch::Tensor& ffn_conv2_bias, const bool shortcut)
+{
+    _dc_conv1_weight = dc_conv1_weight;
+    _dc_conv1_bias = dc_conv1_bias;
+    _dc_depth_conv_weight = dc_depth_conv_weight;
+    _dc_conv2_weight = dc_conv2_weight;
+    _dc_conv2_bias = F::conv2d(dc_depth_conv_bias.reshape({ 1, -1, 1, 1 }), dc_conv2_weight);
+    _dc_conv2_bias = _dc_conv2_bias.index({ 0, torch::indexing::Slice(), 0, 0 }) + dc_conv2_bias;
+    _ffn_conv1_weight = ffn_conv1_weight;
+    _ffn_conv1_bias = ffn_conv1_bias;
+    _ffn_conv2_weight = ffn_conv2_weight;
+    _ffn_conv2_bias = ffn_conv2_bias;
+    _shortcut = shortcut;
+}
+
+void DepthConvProxy::set_param_with_adaptor(
+    const torch::Tensor& dc_conv1_weight, const torch::Tensor& dc_conv1_bias,
+    const torch::Tensor& dc_depth_conv_weight, const torch::Tensor& dc_depth_conv_bias,
+    const torch::Tensor& dc_conv2_weight, const torch::Tensor& dc_conv2_bias,
+    const torch::Tensor& ffn_conv1_weight, const torch::Tensor& ffn_conv1_bias,
+    const torch::Tensor& ffn_conv2_weight, const torch::Tensor& ffn_conv2_bias,
+    const torch::Tensor& adaptor_weight, const torch::Tensor& adaptor_bias, const bool shortcut)
+{
+    _dc_conv1_weight = F::conv2d(torch::transpose(adaptor_weight, 0, 1), dc_conv1_weight);
+    _dc_conv1_weight = torch::transpose(_dc_conv1_weight, 0, 1);
+    _dc_conv1_weight = torch::cat({ _dc_conv1_weight, adaptor_weight }, 0);
+    _dc_conv1_bias = F::conv2d(adaptor_bias.reshape({ 1, -1, 1, 1 }), dc_conv1_weight);
+    _dc_conv1_bias = _dc_conv1_bias.index({ 0, torch::indexing::Slice(), 0, 0 }) + dc_conv1_bias;
+    _dc_depth_conv_weight = dc_depth_conv_weight;
+    _dc_conv2_weight = dc_conv2_weight;
+    _dc_conv2_bias = F::conv2d(dc_depth_conv_bias.reshape({ 1, -1, 1, 1 }), dc_conv2_weight);
+    _dc_conv2_bias = _dc_conv2_bias.index({ 0, torch::indexing::Slice(), 0, 0 }) + dc_conv2_bias;
+    _dc_conv2_bias = _dc_conv2_bias + adaptor_bias;
+    _ffn_conv1_weight = ffn_conv1_weight;
+    _ffn_conv1_bias = ffn_conv1_bias;
+    _ffn_conv2_weight = ffn_conv2_weight;
+    _ffn_conv2_bias = ffn_conv2_bias;
+    _shortcut = shortcut;
+    _adaptor = true;
+}
+
+std::tuple<torch::Tensor, torch::Tensor> DepthConvProxy::forward_common(const torch::Tensor& x)
+{
+    auto identity = x;
+    // depthconv
+    torch::Tensor out;
+    if (_adaptor) {
+        // NOTE: Here we always fuse the adaptor with the first conv1x1 (even when in_ch > out_ch).
+        // It brings larger MACs, but it is faster on A100 due to the lower memory cost.
+        auto out_identity = F::conv2d(identity, _dc_conv1_weight);
+        auto chunks = torch::chunk(out_identity, 2, 1);
+        out = chunks[0];
+        identity = chunks[1];
+    } else {
+        out = F::conv2d(identity, _dc_conv1_weight);
+    }
+    out = bias_wsilu_depthwise_conv2d_cuda(out, _dc_depth_conv_weight, _dc_conv1_bias);
+    out = F::conv2d(out, _dc_conv2_weight);
+
+    if (_shortcut) {
+        bias_shortcut_2_cuda(out, _dc_conv2_bias, identity);
+    } else {
+        bias_shortcut_cuda(out, _dc_conv2_bias, identity);
+        identity = out;
+    }
+    // ffn
+    out = F::conv2d(out, _ffn_conv1_weight);
+    bias_wsilu_chunk_add_cuda(out, _ffn_conv1_bias);
+    out = F::conv2d(out, _ffn_conv2_weight);
+    return { out, identity };
+}
+
+torch::Tensor DepthConvProxy::forward(const torch::Tensor& x)
+{
+    auto [out, identity] = forward_common(x);
+    bias_shortcut_cuda(out, _ffn_conv2_bias, identity);
+    return out;
+}
+
+torch::Tensor DepthConvProxy::forward_with_quant_step(const torch::Tensor& x,
+                                                      const torch::Tensor& quant_step)
+{
+    auto [out, identity] = forward_common(x);
+    bias_shortcut_with_quant_step_cuda(out, _ffn_conv2_bias, quant_step, identity);
+    return out;
+}
+
+torch::Tensor DepthConvProxy::forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
+                                               const bool cat_at_front)
+{
+    auto [t, identity] = forward_common(x);
+
+    auto t_shape = t.sizes();
+    auto B = t_shape[0];
+    auto C = t_shape[1];
+    auto H = t_shape[2];
+    auto W = t_shape[3];
+    auto add_ch = to_cat.sizes()[1];
+    auto out = torch::empty({ B, C + add_ch, H, W }, t.options());
+    if (cat_at_front) {
+        auto t_out = out.narrow(1, add_ch, C);
+        bias_shortcut_no_inplace_cuda(t_out, t, _ffn_conv2_bias, identity);
+        out.narrow(1, 0, add_ch) = to_cat;
+    } else {
+        auto t_out = out.narrow(1, 0, C);
+        bias_shortcut_no_inplace_cuda(t_out, t, _ffn_conv2_bias, identity);
+        out.narrow(1, C, add_ch) = to_cat;
+    }
+    return out;
+}
+
+void SubpelConv2xProxy::set_param(const torch::Tensor& weight, const torch::Tensor& bias,
+                                  const int padding)
+{
+    _weight = weight;
+    _bias = bias;
+    _padding = padding;
+}
+
+torch::Tensor SubpelConv2xProxy::forward(const torch::Tensor& x)
+{
+    auto t = F::conv2d(x, _weight, F::Conv2dFuncOptions().padding(_padding));
+    auto t_shape = t.sizes();
+    auto B = t_shape[0];
+    auto C = t_shape[1];
+    auto H = t_shape[2];
+    auto W = t_shape[3];
+    auto out = torch::empty({ B, C / 4, H * 2, W * 2 }, t.options());
+    assert(B == 1);
+    bias_pixel_shuffle_2_cuda(out, t, _bias, C, H * W, W);
+    return out;
+}
+
+torch::Tensor SubpelConv2xProxy::forward_with_cat(const torch::Tensor& x, const torch::Tensor& to_cat,
+                                                  const bool cat_at_front)
+{
+    auto t = F::conv2d(x, _weight, F::Conv2dFuncOptions().padding(_padding));
+    auto t_shape = t.sizes();
+    auto B = t_shape[0];
+    auto C = t_shape[1];
+    auto H = t_shape[2];
+    auto W = t_shape[3];
+    auto add_ch = to_cat.sizes()[1];
+    auto out = torch::empty({ B, add_ch + C / 4, H * 2, W * 2 }, t.options());
+    assert(B == 1);
+    if (cat_at_front) {
+        auto t_out = out.narrow(1, add_ch, C / 4);
+        bias_pixel_shuffle_2_cuda(t_out, t, _bias, C, H * W, W);
+        out.narrow(1, 0, add_ch) = to_cat;
+    } else {
+        auto t_out = out.narrow(1, 0, C
/ 4);
+        bias_pixel_shuffle_2_cuda(t_out, t, _bias, C, H * W, W);
+        out.narrow(1, C / 4, add_ch) = to_cat;
+    }
+    return out;
+}
diff --git a/DCVC-RT/src/layers/extensions/inference/kernel.cu b/DCVC-RT/src/layers/extensions/inference/kernel.cu
new file mode 100644
index 0000000..41ebc78
--- /dev/null
+++ b/DCVC-RT/src/layers/extensions/inference/kernel.cu
@@ -0,0 +1,1150 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/util/Half.h>
+#include <cassert>
+#include <cstdint>
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+
+#include "common.h"
+#include "def.h"
+#include <tuple>
+
+template <typename vec_t>
+__forceinline__ __host__ bool can_vectorize(void* pointer)
+{
+    uint64_t address = reinterpret_cast<uint64_t>(pointer);
+    constexpr int vec4_alignment = std::alignment_of<vec_t>::value;
+    return address % vec4_alignment == 0;
+}
+
+template <typename vec_t>
+__forceinline__ std::tuple<dim3, dim3, cudaStream_t, bool, bool, int, int>
+get_kernel_launch_info(const torch::Tensor& x, const int cDiv = 1, const bool allow_useVec = true)
+{
+    const torch::IntArrayRef x_shape = x.sizes();
+    const int B = x_shape[0];
+    assert(B == 1);
+    const int C = x_shape[1];
+    const int HW = x_shape[2] * x_shape[3];
+    const int N = C * HW / cDiv;
+    const int BLOCK_SIZE = 128;
+    const dim3 blockDim(BLOCK_SIZE);
+    const bool useVec = allow_useVec && N % 4 == 0 && can_vectorize<vec_t>(x.data_ptr());
+    const bool biasSafe = HW % 4 == 0;
+    const int factor = useVec ? 4 : 1;
+    const dim3 gridDim((N / factor + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    return { blockDim, gridDim, at::cuda::getCurrentCUDAStream(), useVec, biasSafe, N / factor, HW };
+}
+
+template <typename vec_t>
+__forceinline__ std::tuple<dim3, dim3, cudaStream_t, bool, int>
+get_kernel_launch_info_flatten(const torch::Tensor& x)
+{
+    const int N = x.numel();
+    const int BLOCK_SIZE = 128;
+    const dim3 blockDim(BLOCK_SIZE);
+    const bool useVec = N % 4 == 0 && can_vectorize<vec_t>(x.data_ptr());
+    const int factor = useVec ? 4 : 1;
+    const dim3 gridDim((N / factor + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    return { blockDim, gridDim, at::cuda::getCurrentCUDAStream(), useVec, N / factor };
+}
+
+template <typename T, typename scalar_t, bool forceZero>
+__global__ void process_with_mask_kernel(GPUTensor1D<T> y_res, GPUTensor1D<T> y_q,
+                                         GPUTensor1D<T> y_hat, GPUTensor1D<T> s_hat,
+                                         const GPUTensor1D<T> y, const GPUTensor1D<T> scales,
+                                         const GPUTensor1D<T> means, const GPUTensor1D<T> mask,
+                                         const scalar_t force_zero_thres, const int N)
+{
+    const scalar_t __min_val = static_cast<scalar_t>(-128.f);
+    const scalar_t __max_val = static_cast<scalar_t>(127.f);
+    const int chw = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (chw < N) {
+        T _y = y[chw];
+        T _scale = scales[chw];
+        T _means = means[chw];
+        T _mask = mask[chw];
+
+        T _s_hat = _scale * _mask;
+        T _means_hat = _means * _mask;
+        T _y_res = (_y - _means_hat) * _mask;
+        T _y_q = round(_y_res);
+
+        if constexpr (forceZero) {
+            _y_q = _y_q * (_s_hat > force_zero_thres);
+        }
+        _y_q = max(min(_y_q, __max_val), __min_val);
+        T _y_hat = _y_q + _means_hat;
+
+        y_res[chw] = _y_res;
+        y_q[chw] = _y_q;
+        y_hat[chw] = _y_hat;
+        s_hat[chw] = _s_hat;
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void
+process_with_mask_dispatcher(torch::Tensor& y_res, torch::Tensor& y_q, torch::Tensor& y_hat,
+                             torch::Tensor& s_hat, const torch::Tensor& y,
+                             const torch::Tensor& scales, const torch::Tensor& means,
+                             const torch::Tensor& mask, const float force_zero_thres)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(y);
+    const bool force_zero = force_zero_thres > 0.f;
+
+    auto launch_kernel = [&](auto in_v) {
+        using in_t = decltype(in_v);
+        if (force_zero) {
+            process_with_mask_kernel<in_t, scalar_t, true>
+                <<<gridDim, blockDim, 0, stream>>>(y_res, y_q, y_hat, s_hat, y, scales, means, mask,
+                                                   static_cast<scalar_t>(force_zero_thres), N);
+        } else {
+            process_with_mask_kernel<in_t, scalar_t, false>
+                <<<gridDim, blockDim, 0, stream>>>(y_res, y_q, y_hat, s_hat, y, scales, means, mask,
+                                                   static_cast<scalar_t>(force_zero_thres), N);
+        }
+    };
+
+    if (useVec) {
+        launch_kernel(vec_t{});
+    } else {
+        launch_kernel(scalar_t{});
+    }
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+process_with_mask_cuda(const torch::Tensor& y, const torch::Tensor& scales, const torch::Tensor& means,
+                       const torch::Tensor& mask, const float force_zero_thres)
+{
+    auto y_res = torch::empty_like(y);
+    auto y_q = torch::empty_like(y);
+    auto y_hat = torch::empty_like(y);
+    auto s_hat = torch::empty_like(y);
+
+    if (y.dtype() == torch::kFloat32) {
+        process_with_mask_dispatcher<float, float4>(y_res, y_q, y_hat, s_hat, y, scales, means,
+                                                    mask, force_zero_thres);
+    } else if (y.dtype() == torch::kFloat16) {
+        process_with_mask_dispatcher<c10::Half, Half4>(y_res, y_q, y_hat, s_hat, y, scales, means,
+                                                       mask, force_zero_thres);
+    }
+
+    return { y_res, y_q, y_hat, s_hat };
+}
+
+template <typename T>
+__global__ void combine_for_reading_2x_kernel(GPUTensor1D<T> out, const GPUTensor1D<T> x,
+                                              const GPUTensor1D<T> mask, const int N)
+{
+    const int chw1 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int chw2 = chw1 + N;
+
+    if (chw1 < N) {
+        T _s1 = x[chw1];
+        T _s2 = x[chw2];
+        T _m1 = mask[chw1];
+        T _m2 = mask[chw2];
+
+        _s1 = _s1 * _m1;
+        _s2 = _s2 * _m2;
+        out[chw1] = _s1 + _s2;
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void combine_for_reading_2x_dispatcher(torch::Tensor& out, const torch::Tensor& x,
+                                                       const torch::Tensor& mask)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x, 2);
+    if (useVec) {
+        combine_for_reading_2x_kernel<vec_t><<<gridDim, blockDim, 0, stream>>>(out, x, mask, N);
+    } else {
+        combine_for_reading_2x_kernel<scalar_t><<<gridDim, blockDim, 0, stream>>>(out, x, mask, N);
+    }
+}
+
+void combine_for_reading_2x_cuda(torch::Tensor& out, const torch::Tensor& x, const torch::Tensor& mask)
+{
+    if (x.dtype() == torch::kFloat32) {
+        combine_for_reading_2x_dispatcher<float, float4>(out, x, mask);
+    } else if (x.dtype() == torch::kFloat16) {
+        combine_for_reading_2x_dispatcher<c10::Half, Half4>(out, x, mask);
+    }
+}
+
+template <typename T>
+__global__ void restore_y_2x_kernel(GPUTensor1D<T> out, const GPUTensor1D<T> y,
+                                    const GPUTensor1D<T> means, const GPUTensor1D<T> mask, const int N)
+{
+    const int chw1 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int chw2 = chw1 + N;
+
+    if (chw1 < N) {
+        T _y = y[chw1];
+        T _means1 = means[chw1];
+        T _means2 = means[chw2];
+        T _mask1 = mask[chw1];
+        T _mask2 = mask[chw2];
+
+        _means1 = (_y + _means1) * _mask1;
+        _means2 = (_y + _means2) * _mask2;
+        out[chw1] = _means1;
+        out[chw2] = _means2;
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void restore_y_2x_dispatcher(torch::Tensor& out, const torch::Tensor& y,
+                                             const torch::Tensor& means, const torch::Tensor& mask)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(y);
+    if (useVec) {
+        restore_y_2x_kernel<vec_t><<<gridDim, blockDim, 0, stream>>>(out, y, means, mask, N);
+    } else {
+        restore_y_2x_kernel<scalar_t><<<gridDim, blockDim, 0, stream>>>(out, y, means, mask, N);
+    }
+}
+
+void restore_y_2x_cuda(torch::Tensor& out, const torch::Tensor& y, const torch::Tensor& means,
+                       const torch::Tensor& mask)
+{
+    if (y.dtype() == torch::kFloat32) {
+        restore_y_2x_dispatcher<float, float4>(out, y, means, mask);
+    } else if (y.dtype() == torch::kFloat16) {
+        restore_y_2x_dispatcher<c10::Half, Half4>(out, y, means, mask);
+    }
+}
+
+template <typename T>
+__global__ void restore_y_4x_kernel(GPUTensor1D<T> out, const GPUTensor1D<T> y,
+                                    const GPUTensor1D<T> means, const GPUTensor1D<T> mask, const int N)
+{
+    const int chw1 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int chw2 = chw1 + N;
+    const int chw3 = chw2 + N;
+    const int chw4 = chw3 + N;
+
+    if (chw1 < N) {
+        T _y = y[chw1];
+        T _means1 = means[chw1];
+        T _means2 = means[chw2];
+        T _means3 = means[chw3];
+        T _means4 = means[chw4];
+        T _mask1 = mask[chw1];
+        T _mask2 = mask[chw2];
+        T _mask3 = mask[chw3];
+        T _mask4 = mask[chw4];
+
+        _means1 = (_y + _means1) * _mask1;
+        _means2 = (_y + _means2) * _mask2;
+        _means3 = (_y + _means3) * _mask3;
+        _means4 = (_y + _means4) * _mask4;
+        out[chw1] = _means1;
+        out[chw2] = _means2;
+        out[chw3] = _means3;
+        out[chw4] = _means4;
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void restore_y_4x_dispatcher(torch::Tensor& out, const torch::Tensor& y,
+                                             const torch::Tensor& means, const torch::Tensor& mask)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(y);
+    if (useVec) {
+        restore_y_4x_kernel<vec_t><<<gridDim, blockDim, 0, stream>>>(out, y, means, mask, N);
+    } else {
+        restore_y_4x_kernel<scalar_t><<<gridDim, blockDim, 0, stream>>>(out, y, means, mask, N);
+    }
+}
+
+void restore_y_4x_cuda(torch::Tensor& out, const torch::Tensor& y, const torch::Tensor& means,
+                       const torch::Tensor& mask)
+{
+    if (y.dtype() == torch::kFloat32) {
+        restore_y_4x_dispatcher<float, float4>(out, y, means, mask);
+    } else if (y.dtype() == torch::kFloat16) {
+        restore_y_4x_dispatcher<c10::Half, Half4>(out, y, means, mask);
+    }
+}
+
+template <typename T, typename scalar_t>
+__forceinline__ __device__ T scale_to_index(T scale, const scalar_t scale_min,
+                                            const scalar_t scale_max, const scalar_t log_scale_min,
+                                            const scalar_t log_step_recip)
+{
+    scale = max(scale, scale_min);
+    scale = min(scale, scale_max);
+    scale = log(scale) - log_scale_min;
+    scale = scale * log_step_recip;
+    return scale;
+}
+
+template <typename in_t, typename out_t, typename cond_out_t, typename scalar_t, bool with_cond>
+__global__ void build_index_dec_kernel(GPUTensor1D<out_t> out, GPUTensor1D<cond_out_t> cond_out,
+                                       const GPUTensor1D<in_t> scales, const scalar_t scale_min,
+                                       const scalar_t scale_max, const scalar_t log_scale_min,
+                                       const scalar_t log_step_recip, const scalar_t skip_thres,
+                                       const int N)
+{
+    const int n = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (n < N) {
+        in_t _scale = scales[n];
+        in_t _index = scale_to_index(_scale, scale_min, scale_max, log_scale_min, log_step_recip);
+        out[n] = to_uint8(_index);
+        if constexpr (with_cond) {
+            cond_out_t _cond = _scale > skip_thres;
+            cond_out[n] = _cond;
+        }
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void
+build_index_dec_dispatcher(torch::Tensor& out, torch::optional<torch::Tensor>& cond_out,
+                           const torch::Tensor& scales, const scalar_t scale_min,
+                           const scalar_t scale_max, const scalar_t log_scale_min,
+                           const scalar_t log_step_recip, const scalar_t skip_thres)
+{
+    auto [blockDim, gridDim, stream, useVec, N] = get_kernel_launch_info_flatten<vec_t>(scales);
+    const bool with_cond = static_cast<float>(skip_thres) > 0.f;
+
+    auto launch_kernel = [&](auto in_v, auto out_v, auto cond_out_v) {
+        using in_t = decltype(in_v);
+        using out_t = decltype(out_v);
+        using cond_out_t = decltype(cond_out_v);
+        if (with_cond) {
+            build_index_dec_kernel<in_t, out_t, cond_out_t, scalar_t, true>
+                <<<gridDim, blockDim, 0, stream>>>(out, cond_out.value(), scales, scale_min, scale_max,
+                                                   log_scale_min, log_step_recip, skip_thres, N);
+        } else {
+            build_index_dec_kernel<in_t, out_t, cond_out_t, scalar_t, false>
+                <<<gridDim, blockDim, 0, stream>>>(out, nullptr, scales, scale_min, scale_max,
+                                                   log_scale_min, log_step_recip, skip_thres, N);
+        }
+    };
+
+    if (useVec) {
+        launch_kernel(vec_t{}, uchar4{}, bool4{});
+    } else {
+        launch_kernel(scalar_t{}, uint8_t{}, bool{});
+    }
+}
+
+void build_index_dec_cuda(torch::Tensor& out, torch::optional<torch::Tensor>& cond_out,
+                          const torch::Tensor& scales, const float scale_min, const float scale_max,
+                          const float log_scale_min, const float log_step_recip, const float skip_thres)
+{
+    if (scales.dtype() == torch::kFloat32) {
+        build_index_dec_dispatcher<float, float4>(out, cond_out, scales, scale_min, scale_max,
+                                                  log_scale_min, log_step_recip, skip_thres);
+    } else if (scales.dtype() == torch::kFloat16) {
+        build_index_dec_dispatcher<c10::Half, Half4>(
+            out, cond_out, scales, static_cast<c10::Half>(scale_min),
+            static_cast<c10::Half>(scale_max), static_cast<c10::Half>(log_scale_min),
+            static_cast<c10::Half>(log_step_recip), static_cast<c10::Half>(skip_thres));
+    }
+}
+
+template <typename in_t, typename out_t, typename cond_out_t, typename scalar_t, bool with_cond>
+__global__ void build_index_enc_kernel(GPUTensor1D<out_t> out, GPUTensor1D<cond_out_t> cond_out,
+                                       const GPUTensor1D<in_t> symbols, const GPUTensor1D<in_t> scales,
+                                       const scalar_t scale_min, const scalar_t scale_max,
+                                       const scalar_t log_scale_min, const scalar_t log_step_recip,
+                                       const scalar_t skip_thres, const int N)
+{
+    const int n = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (n < N) {
+        in_t _scale = scales[n];
+        in_t _symbol = symbols[n];
+        in_t _index = scale_to_index(_scale, scale_min, scale_max, log_scale_min, log_step_recip);
+
+        out[n] = (to_int16(_symbol) << 8) + to_int16(_index);
+        if constexpr (with_cond) {
+            cond_out_t _cond = _scale > skip_thres;
+            cond_out[n] = _cond;
+        }
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void build_index_enc_dispatcher(
+    torch::Tensor& out, torch::optional<torch::Tensor>& cond_out, const torch::Tensor& symbols,
+    const torch::Tensor& scales, const scalar_t scale_min, const scalar_t scale_max,
+    const scalar_t log_scale_min, const scalar_t log_step_recip, const scalar_t skip_thres)
+{
+    auto [blockDim, gridDim, stream, useVec, N] = get_kernel_launch_info_flatten<vec_t>(scales);
+    const bool with_cond = static_cast<float>(skip_thres) > 0.f;
+
+    auto launch_kernel = [&](auto in_v, auto out_v, auto cond_out_v) {
+        using in_t = decltype(in_v);
+        using out_t = decltype(out_v);
+        using cond_out_t = decltype(cond_out_v);
+        if (with_cond) {
+            build_index_enc_kernel<in_t, out_t, cond_out_t, scalar_t, true>
+                <<<gridDim, blockDim, 0, stream>>>(out, cond_out.value(), symbols, scales,
+                                                   scale_min, scale_max, log_scale_min,
+                                                   log_step_recip, skip_thres, N);
+        } else {
+            build_index_enc_kernel<in_t, out_t, cond_out_t, scalar_t, false>
+                <<<gridDim, blockDim, 0, stream>>>(out, nullptr, symbols, scales, scale_min, scale_max,
+                                                   log_scale_min, log_step_recip, skip_thres, N);
+        }
+    };
+
+    if (useVec) {
+        launch_kernel(vec_t{}, short4{}, bool4{});
+    } else {
+        launch_kernel(scalar_t{}, int16_t{}, bool{});
+    }
+}
+
+void build_index_enc_cuda(torch::Tensor& out, torch::optional<torch::Tensor>& cond_out,
+                          const torch::Tensor& symbols, const torch::Tensor& scales,
+                          const float scale_min, const float scale_max, const float log_scale_min,
+                          const float log_step_recip, const float skip_thres)
+{
+    if (scales.dtype() == torch::kFloat32) {
+        build_index_enc_dispatcher<float, float4>(out, cond_out, symbols, scales, scale_min, scale_max,
+                                                  log_scale_min, log_step_recip, skip_thres);
+    } else if (scales.dtype() == torch::kFloat16) {
+        build_index_enc_dispatcher<c10::Half, Half4>(
+            out, cond_out, symbols, scales, static_cast<c10::Half>(scale_min),
+            static_cast<c10::Half>(scale_max), static_cast<c10::Half>(log_scale_min),
+            static_cast<c10::Half>(log_step_recip), static_cast<c10::Half>(skip_thres));
+    }
+}
+
+template <typename scalar_t, typename vec_t, bool biasSafe>
+__forceinline__ __device__ vec_t get_bias(const GPUTensor1D<scalar_t> bias, const int HW, const int chw)
+{
+    vec_t _bias;
+    if constexpr (sizeof(vec_t) / sizeof(scalar_t) == 4) {
+        if constexpr (biasSafe) {
+            scalar_t b = bias[(chw * 4 + 0) / HW];
+            _bias = make_vec4(b, b, b, b);
+        } else {
+            _bias = make_vec4(bias[(chw * 4 + 0) / HW], bias[(chw * 4 + 1) / HW],
+                              bias[(chw * 4 + 2) / HW], bias[(chw * 4 + 3) / HW]);
+        }
+    } else {
+        _bias = bias[(chw) / HW];
+    }
+    return _bias;
+}
+
+template <typename scalar_t, typename vec_t, bool biasSafe>
+__global__ void bias_wsilu_kernel(GPUTensor1D<vec_t> x, const GPUTensor1D<scalar_t> bias,
+                                  const int N, const int HW)
+{
+    const int chw = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (chw < N) {
+        vec_t _bias = get_bias<scalar_t, vec_t, biasSafe>(bias, HW, chw);
+        vec_t _x = x[chw];
+        _x = _x + _bias;
+        _x = wsilu(_x);
+        x[chw] = _x;
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void bias_wsilu_dispatcher(torch::Tensor& x, const torch::Tensor& bias)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
+    if (useVec) {
+        if (biasSafe) {
+            bias_wsilu_kernel<scalar_t, vec_t, true><<<gridDim, blockDim, 0, stream>>>(x, bias, N, HW);
+        } else {
+            bias_wsilu_kernel<scalar_t, vec_t, false><<<gridDim, blockDim, 0, stream>>>(x, bias, N, HW);
+        }
+    } else {
+        bias_wsilu_kernel<scalar_t, scalar_t, false><<<gridDim, blockDim, 0, stream>>>(x, bias, N, HW);
+    }
+}
+
+void bias_wsilu_cuda(torch::Tensor& x, const torch::Tensor& bias)
+{
+    if (x.dtype() == torch::kFloat32) {
+        bias_wsilu_dispatcher<float, float4>(x, bias);
+    } else if (x.dtype() == torch::kFloat16) {
+        bias_wsilu_dispatcher<c10::Half, Half4>(x, bias);
+    }
+}
+
+template <typename scalar_t, typename vec_t, bool biasSafe, bool with_quant, bool with_shortcut>
+__global__ void bias_shortcut_kernel(GPUTensor1D<vec_t> x, const GPUTensor1D<scalar_t> bias,
+                                     const GPUTensor1D<scalar_t> quant_step,
+                                     const GPUTensor1D<vec_t> shortcut, const int N, const int HW)
+{
+    const int chw = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (chw < N) {
+        vec_t _x = x[chw];
+        vec_t _bias = get_bias<scalar_t, vec_t, biasSafe>(bias, HW, chw);
+        _x = _x + _bias;
+        if constexpr (with_shortcut) {
+            vec_t _s = shortcut[chw];
+            _x = _x + _s;
+        }
+        if constexpr (with_quant) {
+            vec_t _q = get_bias<scalar_t, vec_t, biasSafe>(quant_step, HW, chw);
+            _x = _x * _q;
+        }
+        x[chw] = _x;
+    }
+}
+
+template <typename scalar_t, typename vec_t, bool with_quant, bool with_shortcut>
+__forceinline__ void bias_shortcut_dispatcher(torch::Tensor& x, const torch::Tensor& bias,
+                                              const torch::Tensor& quant_step,
+                                              const torch::Tensor& shortcut)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
+    if (useVec) {
+        if (biasSafe) {
+            bias_shortcut_kernel<scalar_t, vec_t, true, with_quant, with_shortcut>
+                <<<gridDim, blockDim, 0, stream>>>(x, bias, quant_step, shortcut, N, HW);
+        } else {
+            bias_shortcut_kernel<scalar_t, vec_t, false, with_quant, with_shortcut>
+                <<<gridDim, blockDim, 0, stream>>>(x, bias, quant_step, shortcut, N, HW);
+        }
+    } else {
+        bias_shortcut_kernel<scalar_t, scalar_t, false, with_quant, with_shortcut>
+            <<<gridDim, blockDim, 0, stream>>>(x, bias, quant_step, shortcut, N, HW);
+    }
+}
+
+void bias_shortcut_cuda(torch::Tensor& x, const torch::Tensor& bias, const torch::Tensor& shortcut)
+{
+    if (x.dtype() == torch::kFloat32) {
+        bias_shortcut_dispatcher<float, float4, false, true>(x, bias, bias, shortcut);
+    } else if (x.dtype() == torch::kFloat16) {
+        bias_shortcut_dispatcher<c10::Half, Half4, false, true>(x, bias, bias, shortcut);
+    }
+}
+
+void bias_quant_cuda(torch::Tensor& x, const torch::Tensor& bias, const torch::Tensor& quant_step)
+{
+    if (x.dtype() == torch::kFloat32) {
+        bias_shortcut_dispatcher<float, float4, true, false>(x, bias, quant_step, bias);
+    } else if (x.dtype() == torch::kFloat16) {
+        bias_shortcut_dispatcher<c10::Half, Half4, true, false>(x, bias, quant_step, bias);
+    }
+}
+
+template <typename scalar_t, typename vec_t, bool biasSafe>
+__global__ void bias_shortcut_no_inplace_kernel(GPUTensor1D<vec_t> out, const GPUTensor1D<vec_t> x,
+                                                const GPUTensor1D<scalar_t> bias,
+                                                const GPUTensor1D<vec_t> shortcut, const int N,
+                                                const int HW)
+{
+    const int chw = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (chw < N) {
+        vec_t _x = x[chw];
+        vec_t _bias = get_bias<scalar_t, vec_t, biasSafe>(bias, HW, chw);
+        _x = _x + _bias;
+        vec_t _s = shortcut[chw];
+        _x = _x + _s;
+        out[chw] = _x;
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void bias_shortcut_no_inplace_dispatcher(torch::Tensor& out, const torch::Tensor& x,
+                                                         const torch::Tensor& bias,
+                                                         const torch::Tensor& shortcut)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
+    if (useVec) {
+        if (biasSafe) {
+            bias_shortcut_no_inplace_kernel<scalar_t, vec_t, true>
+                <<<gridDim, blockDim, 0, stream>>>(out, x, bias, shortcut, N, HW);
+        } else {
+            bias_shortcut_no_inplace_kernel<scalar_t, vec_t, false>
+                <<<gridDim, blockDim, 0, stream>>>(out, x, bias, shortcut, N, HW);
+        }
+    } else {
+        bias_shortcut_no_inplace_kernel<scalar_t, scalar_t, false>
+            <<<gridDim, blockDim, 0, stream>>>(out, x, bias, shortcut, N, HW);
+    }
+}
+void bias_shortcut_no_inplace_cuda(torch::Tensor& out, const torch::Tensor& x,
+                                   const torch::Tensor& bias, const torch::Tensor& shortcut)
+{
+    if (x.dtype() == torch::kFloat32) {
+        bias_shortcut_no_inplace_dispatcher<float, float4>(out, x, bias, shortcut);
+    } else if (x.dtype() == torch::kFloat16) {
+        bias_shortcut_no_inplace_dispatcher<c10::Half, Half4>(out, x, bias, shortcut);
+    }
+}
+
+template <typename scalar_t, typename vec_t, bool biasSafe>
+__global__ void bias_shortcut_2_kernel(GPUTensor1D<vec_t> x, const GPUTensor1D<scalar_t> bias,
+                                       GPUTensor1D<vec_t> shortcut, const int N, const int HW)
+{
+    const int chw = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (chw < N) {
+        vec_t _x = x[chw];
+        vec_t _bias = get_bias<scalar_t, vec_t, biasSafe>(bias, HW, chw);
+        _x = _x + _bias;
+        vec_t _s = shortcut[chw];
+        _x = _x + _s;
+        x[chw] = _x;
+        shortcut[chw] = _x + _s;
+    }
+}
+
+template <typename scalar_t, typename vec_t>
+__forceinline__ void bias_shortcut_2_dispatcher(torch::Tensor& x, const torch::Tensor& bias,
+                                                torch::Tensor& shortcut)
+{
+    auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
+    if (useVec) {
+        if (biasSafe) {
+            bias_shortcut_2_kernel<scalar_t, vec_t, true>
+                <<<gridDim, blockDim, 0, stream>>>(x, bias, shortcut, N, HW);
+        } else {
+            bias_shortcut_2_kernel<scalar_t, vec_t, false>
+                <<<gridDim, blockDim, 0, stream>>>(x, bias, shortcut, N, HW);
+        }
+    } else {
+        bias_shortcut_2_kernel<scalar_t, scalar_t, false>
+            <<<gridDim, blockDim, 0, stream>>>(x, bias, shortcut, N, HW);
+    }
+}
+
+void bias_shortcut_2_cuda(torch::Tensor& x, const torch::Tensor& bias, torch::Tensor& shortcut)
+{
+    if (x.dtype() == torch::kFloat32) {
+        bias_shortcut_2_dispatcher<float, float4>(x, bias, shortcut);
+    } else if (x.dtype() == torch::kFloat16) {
+        bias_shortcut_2_dispatcher<c10::Half, Half4>(x, bias, shortcut);
+    }
+}
+
+void bias_shortcut_with_quant_step_cuda(torch::Tensor& x, const torch::Tensor& bias,
+                                        const torch::Tensor& quant_step, const torch::Tensor& shortcut)
+{
+    if (x.dtype() == torch::kFloat32) {
+        bias_shortcut_dispatcher<float, float4, true, true>(x, bias, quant_step, shortcut);
+    } else if (x.dtype() == torch::kFloat16) {
+        bias_shortcut_dispatcher<c10::Half, Half4, true, true>(x, bias, quant_step, shortcut);
+    }
+}
+
+template <typename scalar_t, typename vec_t, bool biasSafe>
+__global__ void bias_wsilu_chunk_add_kernel(GPUTensor1D<vec_t> x, const GPUTensor1D<scalar_t> bias,
+                                            const int N, const int HW)
+{
+    const int chw1 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int chw2 = chw1 + N;
+
+    if (chw1 < N) {
+        vec_t _x1 = x[chw1];
+        vec_t _bias1 = get_bias<scalar_t, vec_t, biasSafe>(bias, HW, chw1);
+        _x1 = _x1 + _bias1;
+        _x1 = wsilu(_x1);
+
+        vec_t _x2 = x[chw2];
+        vec_t _bias2 = get_bias<scalar_t, vec_t, biasSafe>(bias, HW, chw2);
+        _x2 = _x2 + _bias2;
+        _x2 = wsilu(_x2);
+
+        x[chw1] = _x1 +
_x2; + } +} + +template +__forceinline__ void bias_wsilu_chunk_add_dispatcher(torch::Tensor& x, const torch::Tensor& bias) +{ + auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info(x, 2); + if (useVec) { + if (biasSafe) { + bias_wsilu_chunk_add_kernel + <<>>(x, bias, N, HW); + } else { + bias_wsilu_chunk_add_kernel + <<>>(x, bias, N, HW); + } + } else { + bias_wsilu_chunk_add_kernel + <<>>(x, bias, N, HW); + } +} + +void bias_wsilu_chunk_add_cuda(torch::Tensor& x, const torch::Tensor& bias) +{ + if (x.dtype() == torch::kFloat32) { + bias_wsilu_chunk_add_dispatcher(x, bias); + } else if (x.dtype() == torch::kFloat16) { + bias_wsilu_chunk_add_dispatcher(x, bias); + } + const torch::IntArrayRef x_shape = x.sizes(); + x = x.narrow(1, 0, x_shape[1] / 2); +} + +template +__global__ void bias_pixel_shuffle_2_kernel(Packed4DTensorAccessor32 out, + const Packed4DTensorAccessor32 x, + const Packed1DTensorAccessor32 bias, + const int N, const int W) +{ + const int c = blockIdx.y * 4; + const int c1 = c / 4; + const int hw = blockIdx.x * blockDim.x + threadIdx.x; + const int h = (hw / W) * 2; + const int w = (hw % W) * 2; + + __shared__ scalar_t _bias[4]; + if (threadIdx.x < 4) { + _bias[threadIdx.x] = bias[c + threadIdx.x]; + } + __syncthreads(); + + if (hw < N) { + scalar_t _bias_0 = _bias[0]; + scalar_t _bias_1 = _bias[1]; + scalar_t _bias_2 = _bias[2]; + scalar_t _bias_3 = _bias[3]; + + scalar_t _x0 = x[0][c + 0][0][hw]; + scalar_t _x1 = x[0][c + 1][0][hw]; + scalar_t _x2 = x[0][c + 2][0][hw]; + scalar_t _x3 = x[0][c + 3][0][hw]; + + _x0 = _x0 + _bias_0; + _x1 = _x1 + _bias_1; + _x2 = _x2 + _bias_2; + _x3 = _x3 + _bias_3; + + out[0][c1][h + 0][w + 0] = _x0; + out[0][c1][h + 0][w + 1] = _x1; + out[0][c1][h + 1][w + 0] = _x2; + out[0][c1][h + 1][w + 1] = _x3; + } +} + +template +__forceinline__ void bias_pixel_shuffle_2_dispatcher(torch::Tensor& out, const torch::Tensor& x, + const torch::Tensor& bias, const int C, + const int N, const int W) +{ 
+ const int BLOCK_SIZE = 128; + const dim3 gridDim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, C / 4); + const dim3 blockDim(BLOCK_SIZE); + auto stream = at::cuda::getCurrentCUDAStream(); + bias_pixel_shuffle_2_kernel<<>>( + out.packed_accessor32(), + x.packed_accessor32(), + bias.packed_accessor32(), N, W); +} + +void bias_pixel_shuffle_2_cuda(torch::Tensor& out, const torch::Tensor& x, + const torch::Tensor& bias, const int C, const int N, const int W) +{ + if (x.dtype() == torch::kFloat32) { + bias_pixel_shuffle_2_dispatcher(out, x, bias, C, N, W); + } else if (x.dtype() == torch::kFloat16) { + bias_pixel_shuffle_2_dispatcher(out, x, bias, C, N, W); + } +} + +template +__global__ void bias_pixel_shuffle_8_kernel(Packed4DTensorAccessor32 out, + const Packed4DTensorAccessor32 x, + const Packed1DTensorAccessor32 bias, + const int N, const int W) +{ + const int c = blockIdx.y * 64; + const int c1 = c / 64; + const int hw = blockIdx.x * blockDim.x + threadIdx.x; + const int h = (hw / W) * 8; + const int w = (hw % W) * 8; + + __shared__ scalar_t _bias[64]; + if (threadIdx.x < 64) { + _bias[threadIdx.x] = bias[c + threadIdx.x]; + } + __syncthreads(); + + if (hw < N) { + for (int i = 0; i < 64; i++) { + scalar_t _x = x[0][c + i][0][hw]; + _x = _x + _bias[i]; + const int out_y_offset = i >> 3; + const int out_x_offset = i & 7; + if constexpr (clamp) { + _x = max(_x, static_cast(0.f)); + _x = min(_x, static_cast(1.f)); + } + out[0][c1][h + out_y_offset][w + out_x_offset] = _x; + } + } +} + +template +__forceinline__ void bias_pixel_shuffle_8_dispatcher(torch::Tensor& out, const torch::Tensor& x, + const torch::Tensor& bias, const int C, + const int N, const int W, bool clamp) +{ + const int BLOCK_SIZE = 128; + const dim3 gridDim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, C / 64); + const dim3 blockDim(BLOCK_SIZE); + auto stream = at::cuda::getCurrentCUDAStream(); + if (clamp) { + bias_pixel_shuffle_8_kernel<<>>( + out.packed_accessor32(), + x.packed_accessor32(), + 
bias.packed_accessor32(), N, W); + } else { + bias_pixel_shuffle_8_kernel<<>>( + out.packed_accessor32(), + x.packed_accessor32(), + bias.packed_accessor32(), N, W); + } +} + +void bias_pixel_shuffle_8_cuda(torch::Tensor& out, const torch::Tensor& x, const torch::Tensor& bias, + const int C, const int N, const int W, bool clamp) +{ + if (x.dtype() == torch::kFloat32) { + bias_pixel_shuffle_8_dispatcher(out, x, bias, C, N, W, clamp); + } else if (x.dtype() == torch::kFloat16) { + bias_pixel_shuffle_8_dispatcher(out, x, bias, C, N, W, clamp); + } +} + +template +__global__ void round_and_to_int8_kernel(GPUTensor1D z, GPUTensor1D z_int8, const int N) +{ + const int chw = blockIdx.x * blockDim.x + threadIdx.x; + + if (chw < N) { + vec_t1 _z = z[chw]; + _z = round(_z); + _z = max(_z, static_cast(-128.f)); + _z = min(_z, static_cast(127.f)); + z[chw] = _z; + vec_t2 _z_int8 = to_int8(_z); + z_int8[chw] = _z_int8; + } +} + +template +__forceinline__ void round_and_to_int8_dispatcher(torch::Tensor& z, torch::Tensor& z_int8) +{ + auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info(z); + if (useVec) { + round_and_to_int8_kernel + <<>>(z, z_int8, N); + } else { + round_and_to_int8_kernel + <<>>(z, z_int8, N); + } +} + +torch::Tensor round_and_to_int8_cuda(torch::Tensor& z) +{ + auto z_int8 = torch::empty_like(z, at::TensorOptions().dtype(torch::kInt8)); + if (z.dtype() == torch::kFloat32) { + round_and_to_int8_dispatcher(z, z_int8); + } else if (z.dtype() == torch::kFloat16) { + round_and_to_int8_dispatcher(z, z_int8); + } + return z_int8; +} + +template +__global__ void clamp_reciprocal_with_quant_kernel(GPUTensor1D q_dec_clamp, + const GPUTensor1D q_dec, GPUTensor1D y, + const scalar_t min_val, const int N) +{ + const int chw = blockIdx.x * blockDim.x + threadIdx.x; + + if (chw < N) { + vec_t _q_dec = q_dec[chw]; + vec_t _y = y[chw]; + _q_dec = max(_q_dec, min_val); + q_dec_clamp[chw] = _q_dec; + vec_t _q_enc = reciprocal(_q_dec); + _y = _y * 
_q_enc; + y[chw] = _y; + } +} + +template +__forceinline__ void clamp_reciprocal_with_quant_dispatcher(torch::Tensor& q_dec_clamp, + const torch::Tensor& q_dec, + torch::Tensor& y, const float min_val) +{ + auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info(q_dec); + if (useVec) { + clamp_reciprocal_with_quant_kernel<<>>( + q_dec_clamp, q_dec, y, static_cast(min_val), N); + } else { + clamp_reciprocal_with_quant_kernel<<>>( + q_dec_clamp, q_dec, y, static_cast(min_val), N); + } +} + +torch::Tensor clamp_reciprocal_with_quant_cuda(const torch::Tensor& q_dec, torch::Tensor& y, + const float min_val) +{ + auto q_dec_clamp = torch::empty_like(q_dec); + if (q_dec.dtype() == torch::kFloat32) { + clamp_reciprocal_with_quant_dispatcher(q_dec_clamp, q_dec, y, min_val); + } else if (q_dec.dtype() == torch::kFloat16) { + clamp_reciprocal_with_quant_dispatcher(q_dec_clamp, q_dec, y, min_val); + } + return q_dec_clamp; +} + +template +__global__ void add_and_multiply_kernel(GPUTensor1D x0, const GPUTensor1D x1, + const GPUTensor1D q, const int N) +{ + const int chw = blockIdx.x * blockDim.x + threadIdx.x; + + if (chw < N) { + T _x0 = x0[chw]; + T _x1 = x1[chw]; + T _q = q[chw]; + _x0 = _x0 + _x1; + _x0 = _x0 * _q; + x0[chw] = _x0; + } +} + +template +__forceinline__ void add_and_multiply_dispatcher(torch::Tensor& x0, const torch::Tensor& x1, + const torch::Tensor& q) +{ + auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info(x0); + if (useVec) { + add_and_multiply_kernel<<>>(x0, x1, q, N); + } else { + add_and_multiply_kernel<<>>(x0, x1, q, N); + } +} + +void add_and_multiply_cuda(torch::Tensor& x0, const torch::Tensor& x1, const torch::Tensor q) +{ + if (x0.dtype() == torch::kFloat32) { + add_and_multiply_dispatcher(x0, x1, q); + } else if (x0.dtype() == torch::kFloat16) { + add_and_multiply_dispatcher(x0, x1, q); + } +} + +template +__global__ void replicate_pad_kernel(Packed4DTensorAccessor32 out, + const 
Packed4DTensorAccessor32 x, const int C, + const int H, const int W, const int H_padded, const int W_padded) +{ + const int b = blockIdx.y; + const int n = blockIdx.x * blockDim.x + threadIdx.x; + + if (n < H_padded * W_padded) { + const int dst_y = n / W_padded; + const int dst_x = n % W_padded; + const int src_y = min(dst_y, H - 1); + const int src_x = min(dst_x, W - 1); + for (int i = 0; i < C; i++) { + scalar_t _x = x[b][i][src_y][src_x]; + out[b][i][dst_y][dst_x] = _x; + } + } +} + +template +__forceinline__ void replicate_pad_dispatcher(torch::Tensor& out, const torch::Tensor& x, + const int B, const int C, const int H, const int W, + const int padB, const int padR) +{ + const int totalOutPixel = (H + padB) * (W + padR); + const int BLOCK_SIZE = 128; + const dim3 blockDim(BLOCK_SIZE); + const dim3 gridDim((totalOutPixel + BLOCK_SIZE - 1) / BLOCK_SIZE, B); + auto stream = at::cuda::getCurrentCUDAStream(); + + replicate_pad_kernel<<>>( + out.packed_accessor32(), + x.packed_accessor32(), C, H, W, H + padB, W + padR); +} + +torch::Tensor replicate_pad_cuda(const torch::Tensor& x, const int padB, const int padR) +{ + const torch::IntArrayRef x_shape = x.sizes(); + const int B = x_shape[0]; + const int C = x_shape[1]; + const int H = x_shape[2]; + const int W = x_shape[3]; + auto out = torch::empty({ B, C, H + padB, W + padR }, x.options()); + if (x.dtype() == torch::kFloat32) { + replicate_pad_dispatcher(out, x, B, C, H, W, padB, padR); + } else if (x.dtype() == torch::kFloat16) { + replicate_pad_dispatcher(out, x, B, C, H, W, padB, padR); + } else if (x.dtype() == torch::kInt8) { + replicate_pad_dispatcher(out, x, B, C, H, W, padB, padR); + } else if (x.dtype() == torch::kInt16) { + replicate_pad_dispatcher(out, x, B, C, H, W, padB, padR); + } + return out; +} + +template +__global__ void bias_wsilu_depthwise_conv2d_kernel(Packed4DTensorAccessor32 out, + const Packed4DTensorAccessor32 x, + const Packed4DTensorAccessor32 weight, + const Packed1DTensorAccessor32 
bias, + const int B, const int C, const int H, const int W) +{ + const int b = blockIdx.z / C; + const int c = blockIdx.z % C; + const int h = blockIdx.y * BLOCK_SIZE; // start of the block + const int w = blockIdx.x * BLOCK_SIZE; + const int THREAD_NUM = THREAD_NUM_Y * THREAD_NUM_X; + const int t_idx = threadIdx.y * THREAD_NUM_X + threadIdx.x; + + __shared__ T1 x_shared[BLOCK_SIZE + 2][BLOCK_SIZE + 2]; + const T1 __bias = static_cast(bias[c]); + T1 __weight[3][3]; +#pragma unroll + for (int i = 0; i < 3; i++) { +#pragma unroll + for (int j = 0; j < 3; j++) { + __weight[i][j] = static_cast(weight[c][0][i][j]); + } + } + + // load boundary padded pixels + const int read_times = (BLOCK_SIZE * 4 + THREAD_NUM - 1) / THREAD_NUM; + const int boundary_pos = BLOCK_SIZE + 1; + + for (int i = 0; i < read_times; i++) { + int pixel_idx = i * THREAD_NUM + t_idx; + if (pixel_idx < BLOCK_SIZE * 2) { + const int y_offset = pixel_idx / 2 + 1; + const int x_offset = (pixel_idx & 1) * boundary_pos; + const int curr_y = h + y_offset - 1; + const int curr_x = w + x_offset - 1; + if (curr_y < 0 || curr_x < 0 || curr_y >= H || curr_x >= W) { + x_shared[y_offset][x_offset] = static_cast(0.f); + } else { + T1 x_tmp = static_cast(x[b][c][curr_y][curr_x]); + x_shared[y_offset][x_offset] = wsilu(x_tmp + __bias); + } + } else if (pixel_idx < BLOCK_SIZE * 4) { + pixel_idx -= BLOCK_SIZE * 2; + const int y_offset = (pixel_idx & 1) * boundary_pos; + const int x_offset = pixel_idx / 2 + 1; + const int curr_y = h + y_offset - 1; + const int curr_x = w + x_offset - 1; + if (curr_y < 0 || curr_x < 0 || curr_y >= H || curr_x >= W) { + x_shared[y_offset][x_offset] = static_cast(0.f); + } else { + T1 x_tmp = static_cast(x[b][c][curr_y][curr_x]); + x_shared[y_offset][x_offset] = wsilu(x_tmp + __bias); + } + } + } + + // load corner 4 pixels + if (t_idx < 4) { + const int y_offset = (t_idx / 2) * boundary_pos; + const int x_offset = (t_idx & 1) * boundary_pos; + const int curr_y = h + y_offset - 1; + const 
int curr_x = w + x_offset - 1; + if (curr_y < 0 || curr_x < 0 || curr_y >= H || curr_x >= W) { + x_shared[y_offset][x_offset] = static_cast(0.f); + } else { + T1 x_tmp = static_cast(x[b][c][curr_y][curr_x]); + x_shared[y_offset][x_offset] = wsilu(x_tmp + __bias); + } + } + + const int per_y_thread_pix_num = BLOCK_SIZE / THREAD_NUM_Y; + const int per_x_thread_pix_num = BLOCK_SIZE / THREAD_NUM_X; + + for (int t_y = 0; t_y < per_y_thread_pix_num; t_y++) { + for (int t_x = 0; t_x < per_x_thread_pix_num; t_x++) { + const int h_offset = threadIdx.y * per_y_thread_pix_num + t_y + 1; + const int w_offset = threadIdx.x * per_x_thread_pix_num + t_x + 1; + const int curr_y = h + h_offset - 1; + const int curr_x = w + w_offset - 1; + // curr_x and curr_y cannot < 0 + if (curr_y >= H || curr_x >= W) { + x_shared[h_offset][w_offset] = static_cast(0.f); + } else { + T1 x_tmp = static_cast(x[b][c][curr_y][curr_x]); + x_shared[h_offset][w_offset] = wsilu(x_tmp + __bias); + } + } + } + __syncthreads(); + + // calculation + for (int t_y = 0; t_y < per_y_thread_pix_num; t_y++) { + for (int t_x = 0; t_x < per_x_thread_pix_num; t_x++) { + const int h_offset = threadIdx.y * per_y_thread_pix_num + t_y; + const int w_offset = threadIdx.x * per_x_thread_pix_num + t_x; + if (h + h_offset < H && w + w_offset < W) { + T1 r = static_cast(0.f); +#pragma unroll + for (int i = 0; i < 3; i++) { +#pragma unroll + for (int j = 0; j < 3; j++) { + r = multiply_add(__weight[i][j], x_shared[h_offset + i][w_offset + j], r); + } + } + + out[b][c][h + h_offset][w + w_offset] = static_cast(r); + } + } + } +} + +torch::Tensor bias_wsilu_depthwise_conv2d_cuda(const torch::Tensor& x, const torch::Tensor& weight, + const torch::Tensor& bias) +{ + const torch::IntArrayRef x_shape = x.sizes(); + const int B = x_shape[0]; + const int C = x_shape[1]; + const int H = x_shape[2]; + const int W = x_shape[3]; + + auto out = torch::empty_like(x); + + const int BLOCK_SIZE = 32; + const int THREAD_NUM_X = 16; + const int 
THREAD_NUM_Y = 8; + const dim3 gridDim((W + BLOCK_SIZE - 1) / BLOCK_SIZE, (H + BLOCK_SIZE - 1) / BLOCK_SIZE, B * C); + const dim3 blockDim(THREAD_NUM_X, THREAD_NUM_Y); + auto stream = at::cuda::getCurrentCUDAStream(); + if (x.dtype() == torch::kFloat32) { + bias_wsilu_depthwise_conv2d_kernel + <<>>( + out.packed_accessor32(), + x.packed_accessor32(), + weight.packed_accessor32(), + bias.packed_accessor32(), B, C, H, W); + } else if (x.dtype() == torch::kFloat16) { + bias_wsilu_depthwise_conv2d_kernel + <<>>( + out.packed_accessor32(), + x.packed_accessor32(), + weight.packed_accessor32(), + bias.packed_accessor32(), B, C, H, W); + } + return out; +} diff --git a/DCVC-RT/src/layers/extensions/inference/setup.py b/DCVC-RT/src/layers/extensions/inference/setup.py new file mode 100644 index 0000000..6a25fd5 --- /dev/null +++ b/DCVC-RT/src/layers/extensions/inference/setup.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import glob +import sys +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +cxx_flags = ["-O3"] +nvcc_flags = ["-O3", "--use_fast_math", "--extra-device-vectorization", "-arch=native"] +if sys.platform == 'win32': + cxx_flags = ["/O2"] + + +setup( + name='inference_extensions_cuda', + ext_modules=[ + CUDAExtension( + name='inference_extensions_cuda', + sources=glob.glob('*.cpp') + glob.glob('*.cu'), + extra_compile_args={ + "cxx": cxx_flags, + "nvcc": nvcc_flags, + }, + ), + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/DCVC-RT/src/layers/layers.py b/DCVC-RT/src/layers/layers.py new file mode 100644 index 0000000..a51e93d --- /dev/null +++ b/DCVC-RT/src/layers/layers.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import torch +from torch import nn +from .cuda_inference import CUSTOMIZED_CUDA_INFERENCE +if CUSTOMIZED_CUDA_INFERENCE: + from .cuda_inference import DepthConvProxy, SubpelConv2xProxy + + +class WSiLU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sigmoid(4.0 * x) * x + + +class WSiLUChunkAdd(nn.Module): + def __init__(self): + super().__init__() + self.silu = WSiLU() + + def forward(self, x): + x1, x2 = self.silu(x).chunk(2, 1) + return x1 + x2 + + +class SubpelConv2x(nn.Module): + def __init__(self, in_ch, out_ch, kernel_size, padding=0): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch * 4, kernel_size=kernel_size, padding=padding), + nn.PixelShuffle(2), + ) + self.padding = padding + + self.proxy = None + + def forward(self, x, to_cat=None, cat_at_front=True): + if not CUSTOMIZED_CUDA_INFERENCE or not x.is_cuda: + return self.forward_torch(x, to_cat, cat_at_front) + + return self.forward_cuda(x, to_cat, cat_at_front) + + def forward_torch(self, x, to_cat=None, cat_at_front=True): + out = self.conv(x) + if to_cat is None: + return out + if cat_at_front: + return torch.cat((to_cat, out), dim=1) + return torch.cat((out, to_cat), dim=1) + + def forward_cuda(self, x, to_cat=None, cat_at_front=True): + if self.proxy is None: + self.proxy = SubpelConv2xProxy() + self.proxy.set_param(self.conv[0].weight, self.conv[0].bias, self.padding) + + if to_cat is None: + return self.proxy.forward(x) + + return self.proxy.forward_with_cat(x, to_cat, cat_at_front) + + +class DepthConvBlock(nn.Module): + def __init__(self, in_ch, out_ch, shortcut=False, force_adaptor=False): + super().__init__() + self.adaptor = None + if in_ch != out_ch or force_adaptor: + self.adaptor = nn.Conv2d(in_ch, out_ch, 1) + self.shortcut = shortcut + self.dc = nn.Sequential( + nn.Conv2d(out_ch, out_ch, 1), + WSiLU(), + nn.Conv2d(out_ch, out_ch, 3, padding=1, groups=out_ch), + nn.Conv2d(out_ch, out_ch, 1), + ) + self.ffn = 
nn.Sequential( + nn.Conv2d(out_ch, out_ch * 4, 1), + WSiLUChunkAdd(), + nn.Conv2d(out_ch * 2, out_ch, 1), + ) + + self.proxy = None + + def forward(self, x, quant_step=None, to_cat=None, cat_at_front=True): + if not CUSTOMIZED_CUDA_INFERENCE or not x.is_cuda: + return self.forward_torch(x, quant_step, to_cat, cat_at_front) + + return self.forward_cuda(x, quant_step, to_cat, cat_at_front) + + def forward_torch(self, x, quant_step=None, to_cat=None, cat_at_front=True): + if self.adaptor is not None: + x = self.adaptor(x) + out = self.dc(x) + x + out = self.ffn(out) + out + if self.shortcut: + out = out + x + if quant_step is not None: + out = out * quant_step + if to_cat is not None: + if cat_at_front: + out = torch.cat((to_cat, out), dim=1) + else: + out = torch.cat((out, to_cat), dim=1) + return out + + def forward_cuda(self, x, quant_step=None, to_cat=None, cat_at_front=True): + if self.proxy is None: + self.proxy = DepthConvProxy() + if self.adaptor is not None: + self.proxy.set_param_with_adaptor(self.dc[0].weight, self.dc[0].bias, + self.dc[2].weight, self.dc[2].bias, + self.dc[3].weight, self.dc[3].bias, + self.ffn[0].weight, self.ffn[0].bias, + self.ffn[2].weight, self.ffn[2].bias, + self.adaptor.weight, self.adaptor.bias, + self.shortcut) + else: + self.proxy.set_param(self.dc[0].weight, self.dc[0].bias, + self.dc[2].weight, self.dc[2].bias, + self.dc[3].weight, self.dc[3].bias, + self.ffn[0].weight, self.ffn[0].bias, + self.ffn[2].weight, self.ffn[2].bias, + self.shortcut) + + if quant_step is not None: + return self.proxy.forward_with_quant_step(x, quant_step) + if to_cat is not None: + return self.proxy.forward_with_cat(x, to_cat, cat_at_front) + + return self.proxy.forward(x) + + +class ResidualBlockWithStride2(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.down = nn.Conv2d(in_ch, out_ch, 2, stride=2) + self.conv = DepthConvBlock(out_ch, out_ch, shortcut=True) + + def forward(self, x): + x = self.down(x) + out = self.conv(x) 
+ return out + + +class ResidualBlockUpsample(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.up = SubpelConv2x(in_ch, out_ch, 1) + self.conv = DepthConvBlock(out_ch, out_ch, shortcut=True) + + def forward(self, x): + out = self.up(x) + out = self.conv(out) + return out diff --git a/DCVC-RT/src/models/common_model.py b/DCVC-RT/src/models/common_model.py new file mode 100644 index 0000000..dc18b65 --- /dev/null +++ b/DCVC-RT/src/models/common_model.py @@ -0,0 +1,296 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from torch import nn + +from ..layers.cuda_inference import combine_for_reading_2x, \ + restore_y_2x, restore_y_2x_with_cat_after, add_and_multiply, \ + replicate_pad, restore_y_4x, clamp_reciprocal_with_quant +from .entropy_models import BitEstimator, GaussianEncoder, EntropyCoder + + +class CompressionModel(nn.Module): + def __init__(self, z_channel, extra_qp=0): + super().__init__() + + self.z_channel = z_channel + self.entropy_coder = None + self.bit_estimator_z = BitEstimator(64 + extra_qp, z_channel) + self.gaussian_encoder = GaussianEncoder() + + self.masks = {} + self.cuda_streams = {} + + def get_cuda_stream(self, device, idx=0, priority=0): + key = f"{device}_{priority}_{idx}" + if key not in self.cuda_streams: + self.cuda_streams[key] = torch.cuda.Stream(device, priority=priority) + return self.cuda_streams[key] + + @staticmethod + def get_qp_num(): + return 64 + + @staticmethod + def get_padding_size(height, width, p=64): + new_h = (height + p - 1) // p * p + new_w = (width + p - 1) // p * p + padding_right = new_w - width + padding_bottom = new_h - height + return padding_right, padding_bottom + + @staticmethod + def get_downsampled_shape(height, width, p): + new_h = (height + p - 1) // p * p + new_w = (width + p - 1) // p * p + return int(new_h / p + 0.5), int(new_w / p + 0.5) + + def update(self, force_zero_thres=None): + self.entropy_coder = EntropyCoder() + 
self.gaussian_encoder.update(self.entropy_coder, force_zero_thres=force_zero_thres) + self.bit_estimator_z.update(self.entropy_coder) + + def set_use_two_entropy_coders(self, use_two_entropy_coders): + self.entropy_coder.set_use_two_entropy_coders(use_two_entropy_coders) + + def pad_for_y(self, y): + _, _, H, W = y.size() + padding_r, padding_b = self.get_padding_size(H, W, 4) + y_pad = replicate_pad(y, padding_b, padding_r) + return y_pad + + def separate_prior(self, params, is_video=False): + if is_video: + quant_step, scales, means = params.chunk(3, 1) + quant_step = torch.clamp_min(quant_step, 0.5) + q_enc = 1. / quant_step + q_dec = quant_step + else: + q = params[:, :2, :, :] + q_enc, q_dec = (torch.sigmoid(q) * 1.5 + 0.5).chunk(2, 1) + scales, means = params[:, 2:, :, :].chunk(2, 1) + return q_enc, q_dec, scales, means + + @staticmethod + def separate_prior_for_video_encoding(params, y): + q_dec, scales, means = params.chunk(3, 1) + q_dec, y = clamp_reciprocal_with_quant(q_dec, y, 0.5) + return y, q_dec, scales, means + + @staticmethod + def separate_prior_for_video_decoding(params): + quant_step, scales, means = params.chunk(3, 1) + quant_step = torch.clamp_min(quant_step, 0.5) + return quant_step, scales, means + + def process_with_mask(self, y, scales, means, mask): + return self.gaussian_encoder.process_with_mask(y, scales, means, mask) + + @staticmethod + def get_one_mask(micro_mask, height, width, dtype, device): + mask = torch.tensor(micro_mask, dtype=dtype, device=device) + mask = mask.repeat((height + 1) // 2, (width + 1) // 2) + mask = mask[:height, :width] + mask = torch.unsqueeze(mask, 0) + mask = torch.unsqueeze(mask, 0) + return mask + + def get_mask_4x(self, batch, channel, height, width, dtype, device): + curr_mask_str = f"{batch}_{channel}_{width}_{height}_4x" + with torch.no_grad(): + if curr_mask_str not in self.masks: + assert channel % 4 == 0 + m = torch.ones((batch, channel // 4, height, width), dtype=dtype, device=device) + m0 = 
self.get_one_mask(((1, 0), (0, 0)), height, width, dtype, device) + m1 = self.get_one_mask(((0, 1), (0, 0)), height, width, dtype, device) + m2 = self.get_one_mask(((0, 0), (1, 0)), height, width, dtype, device) + m3 = self.get_one_mask(((0, 0), (0, 1)), height, width, dtype, device) + + mask_0 = torch.cat((m * m0, m * m1, m * m2, m * m3), dim=1) + mask_1 = torch.cat((m * m3, m * m2, m * m1, m * m0), dim=1) + mask_2 = torch.cat((m * m2, m * m3, m * m0, m * m1), dim=1) + mask_3 = torch.cat((m * m1, m * m0, m * m3, m * m2), dim=1) + + self.masks[curr_mask_str] = [mask_0, mask_1, mask_2, mask_3] + return self.masks[curr_mask_str] + + def get_mask_2x(self, batch, channel, height, width, dtype, device): + curr_mask_str = f"{batch}_{channel}_{width}_{height}_2x" + with torch.no_grad(): + if curr_mask_str not in self.masks: + assert channel % 2 == 0 + m = torch.ones((batch, channel // 2, height, width), dtype=dtype, device=device) + m0 = self.get_one_mask(((1, 0), (0, 1)), height, width, dtype, device) + m1 = self.get_one_mask(((0, 1), (1, 0)), height, width, dtype, device) + + mask_0 = torch.cat((m * m0, m * m1), dim=1) + mask_1 = torch.cat((m * m1, m * m0), dim=1) + + self.masks[curr_mask_str] = [mask_0, mask_1] + return self.masks[curr_mask_str] + + @staticmethod + def single_part_for_writing_4x(x): + x0, x1, x2, x3 = x.chunk(4, 1) + return (x0 + x1) + (x2 + x3) + + @staticmethod + def single_part_for_writing_2x(x): + x0, x1 = x.chunk(2, 1) + return x0 + x1 + + def compress_prior_2x(self, y, common_params, y_spatial_prior): + y, q_dec, scales, means = self.separate_prior_for_video_encoding(common_params, y) + dtype = y.dtype + device = y.device + B, C, H, W = y.size() + mask_0, mask_1 = self.get_mask_2x(B, C, H, W, dtype, device) + + _, y_q_0, y_hat_0, s_hat_0 = self.process_with_mask(y, scales, means, mask_0) + cat_params = torch.cat((y_hat_0, common_params), dim=1) + scales, means = y_spatial_prior(cat_params).chunk(2, 1) + _, y_q_1, y_hat_1, s_hat_1 = 
self.process_with_mask(y, scales, means, mask_1) + + y_hat = add_and_multiply(y_hat_0, y_hat_1, q_dec) + + y_q_w_0 = self.single_part_for_writing_2x(y_q_0) + y_q_w_1 = self.single_part_for_writing_2x(y_q_1) + s_w_0 = self.single_part_for_writing_2x(s_hat_0) + s_w_1 = self.single_part_for_writing_2x(s_hat_1) + return y_q_w_0, y_q_w_1, s_w_0, s_w_1, y_hat + + def decompress_prior_2x(self, common_params, y_spatial_prior): + infos = self.decompress_prior_2x_part1(common_params) + y_hat = self.decompress_prior_2x_part2(common_params, y_spatial_prior, infos) + return y_hat + + def decompress_prior_2x_part1(self, common_params): + q_dec, scales, means = self.separate_prior_for_video_decoding(common_params) + dtype = means.dtype + device = means.device + B, C, H, W = means.size() + mask_0, mask_1 = self.get_mask_2x(B, C, H, W, dtype, device) + + scales_r = combine_for_reading_2x(scales, mask_0, inplace=False) + indexes, skip_cond = self.gaussian_encoder.build_indexes_decoder(scales_r) + self.gaussian_encoder.decode_y(indexes) + infos = { + "q_dec": q_dec, + "mask_0": mask_0, + "mask_1": mask_1, + "means": means, + "scales_r": scales_r, + "skip_cond": skip_cond, + "indexes": indexes, + } + return infos + + def decompress_prior_2x_part2(self, common_params, y_spatial_prior, infos): + dtype = common_params.dtype + device = common_params.device + y_q_r = self.gaussian_encoder.get_y(infos["scales_r"].shape, + infos["scales_r"].numel(), + dtype, device, + infos["skip_cond"], infos["indexes"]) + y_hat_0, cat_params = restore_y_2x_with_cat_after(y_q_r, infos["means"], infos["mask_0"], + common_params) + scales, means = y_spatial_prior(cat_params).chunk(2, 1) + scales_r = combine_for_reading_2x(scales, infos["mask_1"], inplace=True) + y_q_r = self.gaussian_encoder.decode_and_get_y(scales_r, dtype, device) + y_hat_1 = restore_y_2x(y_q_r, means, infos["mask_1"]) + + y_hat = add_and_multiply(y_hat_0, y_hat_1, infos["q_dec"]) + return y_hat + + def compress_prior_4x(self, y, 
common_params, y_spatial_prior_reduction,
+                           y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
+                           y_spatial_prior_adaptor_3, y_spatial_prior):
+        '''
+        y_0 denotes the first quarter of the channel-wise split
+        y_1 denotes the second quarter of the channel-wise split
+        y_2 denotes the third quarter of the channel-wise split
+        y_3 denotes the fourth quarter of the channel-wise split
+        y_?_0 denotes multiplication with mask_0
+        y_?_1 denotes multiplication with mask_1
+        y_?_2 denotes multiplication with mask_2
+        y_?_3 denotes multiplication with mask_3
+        '''
+        q_enc, q_dec, scales, means = self.separate_prior(common_params, False)
+        common_params = y_spatial_prior_reduction(common_params)
+        dtype = y.dtype
+        device = y.device
+        B, C, H, W = y.size()
+        mask_0, mask_1, mask_2, mask_3 = self.get_mask_4x(B, C, H, W, dtype, device)
+
+        y = y * q_enc
+
+        _, y_q_0, y_hat_0, s_hat_0 = self.process_with_mask(y, scales, means, mask_0)
+
+        y_hat_so_far = y_hat_0
+        params = torch.cat((y_hat_so_far, common_params), dim=1)
+        scales, means = y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(2, 1)
+        _, y_q_1, y_hat_1, s_hat_1 = self.process_with_mask(y, scales, means, mask_1)
+
+        y_hat_so_far = y_hat_so_far + y_hat_1
+        params = torch.cat((y_hat_so_far, common_params), dim=1)
+        scales, means = y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(2, 1)
+        _, y_q_2, y_hat_2, s_hat_2 = self.process_with_mask(y, scales, means, mask_2)
+
+        y_hat_so_far = y_hat_so_far + y_hat_2
+        params = torch.cat((y_hat_so_far, common_params), dim=1)
+        scales, means = y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(2, 1)
+        _, y_q_3, y_hat_3, s_hat_3 = self.process_with_mask(y, scales, means, mask_3)
+
+        y_hat = y_hat_so_far + y_hat_3
+        y_hat = y_hat * q_dec
+
+        y_q_w_0 = self.single_part_for_writing_4x(y_q_0)
+        y_q_w_1 = self.single_part_for_writing_4x(y_q_1)
+        y_q_w_2 = self.single_part_for_writing_4x(y_q_2)
+        y_q_w_3 = self.single_part_for_writing_4x(y_q_3)
+        s_w_0 = self.single_part_for_writing_4x(s_hat_0)
+        s_w_1 = self.single_part_for_writing_4x(s_hat_1)
+
s_w_2 = self.single_part_for_writing_4x(s_hat_2) + s_w_3 = self.single_part_for_writing_4x(s_hat_3) + return y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, s_w_0, s_w_1, s_w_2, s_w_3, y_hat + + def decompress_prior_4x(self, common_params, y_spatial_prior_reduction, + y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2, + y_spatial_prior_adaptor_3, y_spatial_prior): + _, quant_step, scales, means = self.separate_prior(common_params, False) + common_params = y_spatial_prior_reduction(common_params) + dtype = means.dtype + device = means.device + B, C, H, W = means.size() + mask_0, mask_1, mask_2, mask_3 = self.get_mask_4x(B, C, H, W, dtype, device) + + scales_r = self.single_part_for_writing_4x(scales * mask_0) + y_q_r = self.gaussian_encoder.decode_and_get_y(scales_r, dtype, device) + y_hat_curr_step = restore_y_4x(y_q_r, means, mask_0) + y_hat_so_far = y_hat_curr_step + + params = torch.cat((y_hat_so_far, common_params), dim=1) + scales, means = y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(2, 1) + scales_r = self.single_part_for_writing_4x(scales * mask_1) + y_q_r = self.gaussian_encoder.decode_and_get_y(scales_r, dtype, device) + y_hat_curr_step = restore_y_4x(y_q_r, means, mask_1) + y_hat_so_far = y_hat_so_far + y_hat_curr_step + + params = torch.cat((y_hat_so_far, common_params), dim=1) + scales, means = y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(2, 1) + scales_r = self.single_part_for_writing_4x(scales * mask_2) + y_q_r = self.gaussian_encoder.decode_and_get_y(scales_r, dtype, device) + y_hat_curr_step = restore_y_4x(y_q_r, means, mask_2) + y_hat_so_far = y_hat_so_far + y_hat_curr_step + + params = torch.cat((y_hat_so_far, common_params), dim=1) + scales, means = y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(2, 1) + scales_r = self.single_part_for_writing_4x(scales * mask_3) + y_q_r = self.gaussian_encoder.decode_and_get_y(scales_r, dtype, device) + y_hat_curr_step = restore_y_4x(y_q_r, means, mask_3) + y_hat_so_far = y_hat_so_far + 
y_hat_curr_step + + y_hat = y_hat_so_far * quant_step + + return y_hat diff --git a/DCVC-RT/src/models/entropy_models.py b/DCVC-RT/src/models/entropy_models.py new file mode 100644 index 0000000..9b64d28 --- /dev/null +++ b/DCVC-RT/src/models/entropy_models.py @@ -0,0 +1,341 @@ +import math + +import torch +import numpy as np +from torch import nn +import torch.nn.functional as F + +from ..layers.cuda_inference import build_index_dec, build_index_enc, process_with_mask + + +class EntropyCoder(): + def __init__(self): + super().__init__() + + from MLCodec_extensions_cpp import RansEncoder, RansDecoder + self.encoder = RansEncoder() + self.decoder = RansDecoder() + + @staticmethod + def pmf_to_quantized_cdf(pmf, precision=16): + from MLCodec_extensions_cpp import pmf_to_quantized_cdf as _pmf_to_cdf + cdf = _pmf_to_cdf(pmf.tolist(), precision) + cdf = torch.IntTensor(cdf) + return cdf + + @staticmethod + def pmf_to_cdf(pmf, tail_mass, pmf_length, max_length): + entropy_coder_precision = 16 + cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) + for i, p in enumerate(pmf): + prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) + _cdf = EntropyCoder.pmf_to_quantized_cdf(prob, entropy_coder_precision) + cdf[i, : _cdf.size(0)] = _cdf + return cdf + + def reset(self): + self.encoder.reset() + + def add_cdf(self, cdf, cdf_length, offset): + enc_cdf_idx = self.encoder.add_cdf(cdf, cdf_length, offset) + dec_cdf_idx = self.decoder.add_cdf(cdf, cdf_length, offset) + assert enc_cdf_idx == dec_cdf_idx + return enc_cdf_idx + + def encode_y(self, symbols, cdf_group_index): + # symbols: int16, high 8 bits: int8 symbol to be encoded; low 8 bits: uint8 index to use + assert symbols.dtype == torch.int16 + self.encoder.encode_y(symbols.cpu().numpy(), cdf_group_index) + + def encode_z(self, symbols, cdf_group_index, start_offset, per_channel_size): + self.encoder.encode_z(symbols.to(torch.int8).cpu().numpy(), + cdf_group_index, start_offset, per_channel_size) + 
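The `encode_y` comment above specifies an int16 packing: the high 8 bits carry the signed int8 symbol and the low 8 bits the uint8 cdf index (produced on the GPU by `build_index_enc`). As a standalone illustration of just that bit layout — the helper names below are hypothetical and not part of this repo:

```python
def pack_symbol_index(sym: int, idx: int) -> int:
    """Pack a signed int8 symbol (high byte) and a uint8 cdf index
    (low byte) into a single signed int16, as the comment describes."""
    assert -128 <= sym <= 127 and 0 <= idx <= 255
    u = ((sym & 0xFF) << 8) | idx             # combine as unsigned 16-bit
    return u - 0x10000 if u >= 0x8000 else u  # reinterpret as signed int16


def unpack_symbol_index(packed: int) -> tuple:
    """Recover (symbol, index) from the packed int16."""
    u = packed & 0xFFFF
    sym = u >> 8
    return (sym - 256 if sym >= 128 else sym, u & 0xFF)
```

For example, `pack_symbol_index(-3, 200)` gives `-568`, and unpacking `-568` returns `(-3, 200)`; in the codec the equivalent packing happens per element inside the fused CUDA kernel.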
+ def flush(self): + self.encoder.flush() + + def get_encoded_stream(self): + return self.encoder.get_encoded_stream().tobytes() + + def set_stream(self, stream): + self.decoder.set_stream((np.frombuffer(stream, dtype=np.uint8))) + + def decode_y(self, indexes, cdf_group_index): + self.decoder.decode_y(indexes.to(torch.uint8).cpu().numpy(), cdf_group_index) + + def decode_and_get_y(self, indexes, cdf_group_index, device, dtype): + rv = self.decoder.decode_and_get_y(indexes.to(torch.uint8).cpu().numpy(), cdf_group_index) + rv = torch.as_tensor(rv) + return rv.to(device).to(dtype) + + def decode_z(self, total_size, cdf_group_index, start_offset, per_channel_size): + self.decoder.decode_z(total_size, cdf_group_index, start_offset, per_channel_size) + + def get_decoded_tensor(self, device, dtype, non_blocking=False): + rv = self.decoder.get_decoded_tensor() + rv = torch.as_tensor(rv) + return rv.to(device, non_blocking=non_blocking).to(dtype) + + def set_use_two_entropy_coders(self, use_two_entropy_coders): + self.encoder.set_use_two_encoders(use_two_entropy_coders) + self.decoder.set_use_two_decoders(use_two_entropy_coders) + + +class Bitparm(nn.Module): + def __init__(self, qp_num, channel, final=False): + super().__init__() + self.final = final + self.h = nn.Parameter(torch.nn.init.normal_( + torch.empty([qp_num, channel, 1, 1]), 0, 0.01)) + self.b = nn.Parameter(torch.nn.init.normal_( + torch.empty([qp_num, channel, 1, 1]), 0, 0.01)) + if not final: + self.a = nn.Parameter(torch.nn.init.normal_( + torch.empty([qp_num, channel, 1, 1]), 0, 0.01)) + else: + self.a = None + + def forward(self, x, index): + h = torch.index_select(self.h, 0, index) + b = torch.index_select(self.b, 0, index) + x = x * F.softplus(h) + b + if self.final: + return x + + a = torch.index_select(self.a, 0, index) + return x + torch.tanh(x) * torch.tanh(a) + + +class AEHelper(): + def __init__(self): + super().__init__() + self.entropy_coder = None + self.cdf_group_index = None + self._offset = 
None + self._quantized_cdf = None + self._cdf_length = None + + def set_cdf_info(self, quantized_cdf, cdf_length, offset): + self._quantized_cdf = quantized_cdf.cpu().numpy() + self._cdf_length = cdf_length.reshape(-1).int().cpu().numpy() + self._offset = offset.reshape(-1).int().cpu().numpy() + + def get_cdf_info(self): + return self._quantized_cdf, \ + self._cdf_length, \ + self._offset + + +class BitEstimator(AEHelper, nn.Module): + def __init__(self, qp_num, channel): + super().__init__() + self.f1 = Bitparm(qp_num, channel) + self.f2 = Bitparm(qp_num, channel) + self.f3 = Bitparm(qp_num, channel) + self.f4 = Bitparm(qp_num, channel, True) + self.qp_num = qp_num + self.channel = channel + + def forward(self, x, index): + return self.get_cdf(x, index) + + def get_logits_cdf(self, x, index): + x = self.f1(x, index) + x = self.f2(x, index) + x = self.f3(x, index) + x = self.f4(x, index) + return x + + def get_cdf(self, x, index): + return torch.sigmoid(self.get_logits_cdf(x, index)) + + def update(self, entropy_coder): + self.entropy_coder = entropy_coder + + with torch.no_grad(): + device = next(self.parameters()).device + medians = torch.zeros((self.qp_num, self.channel, 1, 1), device=device) + index = torch.arange(self.qp_num, device=device, dtype=torch.int32) + + minima = medians + 8 + for i in range(8, 1, -1): + samples = torch.zeros_like(medians) - i + probs = self.forward(samples, index) + minima = torch.where(probs < torch.zeros_like(medians) + 0.0001, + torch.zeros_like(medians) + i, minima) + + maxima = medians + 8 + for i in range(8, 1, -1): + samples = torch.zeros_like(medians) + i + probs = self.forward(samples, index) + maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999, + torch.zeros_like(medians) + i, maxima) + + minima = minima.int() + maxima = maxima.int() + + offset = -minima + + pmf_start = medians - minima + pmf_length = maxima + minima + 1 + + max_length = pmf_length.max() + device = pmf_start.device + samples = 
torch.arange(max_length, device=device) + + samples = samples[None, None, None, :] + pmf_start + + half = float(0.5) + + lower = self.forward(samples - half, index) + upper = self.forward(samples + half, index) + pmf = upper - lower + + pmf = pmf[:, :, 0, :] + tail_mass = lower[:, :, 0, :1] + (1.0 - upper[:, :, 0, -1:]) + + pmf = pmf.reshape([-1, max_length]) + tail_mass = tail_mass.reshape([-1, 1]) + pmf_length = pmf_length.reshape([-1]) + offset = offset.reshape([-1]) + quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + cdf_length = pmf_length + 2 + self.set_cdf_info(quantized_cdf, cdf_length, offset) + self.cdf_group_index = self.entropy_coder.add_cdf(*self.get_cdf_info()) + + def build_indexes(self, size, qp): + B, C, H, W = size + indexes = torch.arange(C, dtype=torch.int).view(1, -1, 1, 1) + qp * self.channel + return indexes.repeat(B, 1, H, W) + + def encode_z(self, x, qp): + _, _, H, W = x.size() + return self.entropy_coder.encode_z(x.reshape(-1), self.cdf_group_index, qp * self.channel, + H * W) + + def decode_z(self, size, qp): + self.entropy_coder.decode_z(self.channel * size[0] * size[1], self.cdf_group_index, + qp * self.channel, size[0] * size[1]) + + def get_z(self, size, device, dtype): + output_size = (1, self.channel, size[0], size[1]) + val = self.entropy_coder.get_decoded_tensor(device, dtype, non_blocking=True) + return val.reshape(output_size) + + +class GaussianEncoder(AEHelper): + def __init__(self): + super().__init__() + self.scale_min = 0.11 + self.scale_max = 16.0 + self.scale_level = 128 # <= 256 + self.scale_table = self.get_scale_table(self.scale_min, self.scale_max, self.scale_level) + + self.log_scale_min = math.log(self.scale_min) + self.log_scale_max = math.log(self.scale_max) + self.log_scale_step = (self.log_scale_max - self.log_scale_min) / (self.scale_level - 1) + self.log_step_recip = 1.
/ self.log_scale_step + + self.force_zero_thres = None + self.decode_index_cache = {} + self.decode_zeros_cache = {} + + @staticmethod + def get_scale_table(min_val, max_val, levels): + return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), levels)) + + def update(self, entropy_coder, force_zero_thres=None): + self.entropy_coder = entropy_coder + self.force_zero_thres = force_zero_thres + + pmf_center = torch.zeros_like(self.scale_table) + 8 + scales = torch.zeros_like(pmf_center) + self.scale_table + cdf_distribution = torch.distributions.normal.Normal(0., scales) + for i in range(8, 1, -1): + samples = torch.zeros_like(pmf_center) + i + probs = cdf_distribution.cdf(samples) + probs = torch.squeeze(probs) + pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999, + torch.zeros_like(pmf_center) + i, pmf_center) + + pmf_center = pmf_center.int() + pmf_length = 2 * pmf_center + 1 + max_length = torch.max(pmf_length).item() + + device = pmf_center.device + samples = torch.arange(max_length, device=device) - pmf_center[:, None] + samples = samples.float() + + scales = torch.zeros_like(samples) + self.scale_table[:, None] + cdf_distribution = torch.distributions.normal.Normal(0., scales) + + upper = cdf_distribution.cdf(samples + 0.5) + lower = cdf_distribution.cdf(samples - 0.5) + pmf = upper - lower + + tail_mass = 2 * lower[:, :1] + + quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) + quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) + + self.set_cdf_info(quantized_cdf, pmf_length+2, -pmf_center) + self.cdf_group_index = self.entropy_coder.add_cdf(*self.get_cdf_info()) + + def process_with_mask(self, y, scales, means, mask): + return process_with_mask(y, scales, means, mask, self.force_zero_thres) + + def build_indexes_decoder(self, scales): + scales = scales.reshape(-1) + indexes, skip_cond = build_index_dec(scales, self.scale_min, self.scale_max, + self.log_scale_min, self.log_step_recip, + 
self.force_zero_thres) + if self.force_zero_thres is not None: + indexes = indexes[skip_cond] + return indexes, skip_cond + + def build_indexes_encoder(self, symbols, scales): + symbols = symbols.reshape(-1) + scales = scales.reshape(-1) + symbols = build_index_enc(symbols, scales, self.scale_min, self.scale_max, + self.log_scale_min, self.log_step_recip, self.force_zero_thres) + return symbols + + def encode_y(self, x, scales): + symbols = self.build_indexes_encoder(x, scales) + return self.entropy_coder.encode_y(symbols, self.cdf_group_index) + + def get_decode_index_cache(self, num, device): + if num not in self.decode_index_cache: + c = torch.arange(0, num, dtype=torch.int32, device=device) + self.decode_index_cache[num] = c + + return self.decode_index_cache[num] + + def get_decode_zeros_cache(self, num, device): + if num not in self.decode_zeros_cache: + c = torch.zeros(num, dtype=torch.int32, device=device) + self.decode_zeros_cache[num] = c + + return self.decode_zeros_cache[num].clone() + + def decode_and_get_y(self, scales, dtype, device): + indexes, skip_cond = self.build_indexes_decoder(scales) + self.decode_y(indexes) + return self.get_y(scales.shape, scales.numel(), dtype, device, skip_cond, indexes) + + def decode_y(self, indexes): + self.entropy_coder.decode_y(indexes, self.cdf_group_index) + + def get_y(self, shape, numel, dtype, device, skip_cond, indexes): + if len(indexes) == 0: + return torch.zeros(shape, dtype=dtype, device=device) + if skip_cond is not None: + curr_index = self.get_decode_index_cache(numel, device) + back_index = self.get_decode_zeros_cache(numel, device) + back_index.masked_scatter_(skip_cond, curr_index) + val = self.entropy_coder.get_decoded_tensor(device, dtype, non_blocking=True) + if skip_cond is not None: + y = torch.index_select(val, 0, back_index) * skip_cond + return y.reshape(shape) + return val.reshape(shape) diff --git a/DCVC-RT/src/models/image_model.py b/DCVC-RT/src/models/image_model.py new file mode 100644 
index 0000000..fa1ea67 --- /dev/null +++ b/DCVC-RT/src/models/image_model.py @@ -0,0 +1,209 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from torch import nn +import torch.nn.functional as F + + +from .common_model import CompressionModel +from ..layers.layers import DepthConvBlock, ResidualBlockUpsample, ResidualBlockWithStride2 +from ..layers.cuda_inference import CUSTOMIZED_CUDA_INFERENCE, round_and_to_int8 + +g_ch_src = 3 * 8 * 8 +g_ch_enc_dec = 368 + + +class IntraEncoder(nn.Module): + def __init__(self, N): + super().__init__() + + self.enc_1 = DepthConvBlock(g_ch_src, g_ch_enc_dec) + self.enc_2 = nn.Sequential( + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + nn.Conv2d(g_ch_enc_dec, N, 3, stride=2, padding=1), + ) + + def forward(self, x, quant_step): + out = F.pixel_unshuffle(x, 8) + if not CUSTOMIZED_CUDA_INFERENCE or not x.is_cuda: + return self.forward_torch(out, quant_step) + + return self.forward_cuda(out, quant_step) + + def forward_torch(self, out, quant_step): + out = self.enc_1(out) + out = out * quant_step + return self.enc_2(out) + + def forward_cuda(self, out, quant_step): + out = self.enc_1(out, quant_step=quant_step) + return self.enc_2(out) + + +class IntraDecoder(nn.Module): + def __init__(self, N): + super().__init__() + + self.dec_1 = nn.Sequential( + ResidualBlockUpsample(N, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + 
DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + DepthConvBlock(g_ch_enc_dec, g_ch_enc_dec), + ) + self.dec_2 = DepthConvBlock(g_ch_enc_dec, g_ch_src) + + def forward(self, x, quant_step): + if not CUSTOMIZED_CUDA_INFERENCE or not x.is_cuda: + return self.forward_torch(x, quant_step) + + return self.forward_cuda(x, quant_step) + + def forward_torch(self, x, quant_step): + out = self.dec_1(x) + out = out * quant_step + out = self.dec_2(out) + out = F.pixel_shuffle(out, 8) + return out + + def forward_cuda(self, x, quant_step): + out = self.dec_1[0](x) + out = self.dec_1[1](out) + out = self.dec_1[2](out) + out = self.dec_1[3](out) + out = self.dec_1[4](out) + out = self.dec_1[5](out) + out = self.dec_1[6](out) + out = self.dec_1[7](out) + out = self.dec_1[8](out) + out = self.dec_1[9](out) + out = self.dec_1[10](out) + out = self.dec_1[11](out) + out = self.dec_1[12](out, quant_step=quant_step) + out = self.dec_2(out) + out = F.pixel_shuffle(out, 8) + return out + + +class DMCI(CompressionModel): + def __init__(self, N=256, z_channel=128): + super().__init__(z_channel=z_channel) + + self.enc = IntraEncoder(N) + + self.hyper_enc = nn.Sequential( + DepthConvBlock(N, z_channel), + ResidualBlockWithStride2(z_channel, z_channel), + ResidualBlockWithStride2(z_channel, z_channel), + ) + + self.hyper_dec = nn.Sequential( + ResidualBlockUpsample(z_channel, z_channel), + ResidualBlockUpsample(z_channel, z_channel), + DepthConvBlock(z_channel, N), + ) + + self.y_prior_fusion = nn.Sequential( + DepthConvBlock(N, N * 2), + DepthConvBlock(N * 2, N * 2), + DepthConvBlock(N * 2, N * 2), + nn.Conv2d(N * 2, N * 2 + 2, 1), + ) + + self.y_spatial_prior_reduction = nn.Conv2d(N * 2 + 2, N * 1, 1) + self.y_spatial_prior_adaptor_1 = DepthConvBlock(N * 2, N * 2, force_adaptor=True) + self.y_spatial_prior_adaptor_2 = DepthConvBlock(N * 2, N * 2, force_adaptor=True) + self.y_spatial_prior_adaptor_3 = 
DepthConvBlock(N * 2, N * 2, force_adaptor=True) + self.y_spatial_prior = nn.Sequential( + DepthConvBlock(N * 2, N * 2), + DepthConvBlock(N * 2, N * 2), + DepthConvBlock(N * 2, N * 2), + nn.Conv2d(N * 2, N * 2, 1), + ) + + self.dec = IntraDecoder(N) + + self.q_scale_enc = nn.Parameter(torch.ones((self.get_qp_num(), g_ch_enc_dec, 1, 1))) + self.q_scale_dec = nn.Parameter(torch.ones((self.get_qp_num(), g_ch_enc_dec, 1, 1))) + + def compress(self, x, qp): + device = x.device + curr_q_enc = self.q_scale_enc[qp:qp+1, :, :, :] + curr_q_dec = self.q_scale_dec[qp:qp+1, :, :, :] + + y = self.enc(x, curr_q_enc) + y_pad = self.pad_for_y(y) + z = self.hyper_enc(y_pad) + z_hat, z_hat_write = round_and_to_int8(z) + + params = self.hyper_dec(z_hat) + params = self.y_prior_fusion(params) + _, _, yH, yW = y.shape + params = params[:, :, :yH, :yW].contiguous() + y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, s_w_0, s_w_1, s_w_2, s_w_3, y_hat = \ + self.compress_prior_4x( + y, params, self.y_spatial_prior_reduction, + self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2, + self.y_spatial_prior_adaptor_3, self.y_spatial_prior) + + cuda_event = torch.cuda.Event() + cuda_event.record() + x_hat = self.dec(y_hat, curr_q_dec).clamp_(0, 1) + + cuda_stream = self.get_cuda_stream(device=device, priority=-1) + with torch.cuda.stream(cuda_stream): + cuda_event.wait() + self.entropy_coder.reset() + self.bit_estimator_z.encode_z(z_hat_write, qp) + self.gaussian_encoder.encode_y(y_q_w_0, s_w_0) + self.gaussian_encoder.encode_y(y_q_w_1, s_w_1) + self.gaussian_encoder.encode_y(y_q_w_2, s_w_2) + self.gaussian_encoder.encode_y(y_q_w_3, s_w_3) + self.entropy_coder.flush() + + bit_stream = self.entropy_coder.get_encoded_stream() + + torch.cuda.synchronize(device=device) + result = { + "bit_stream": bit_stream, + "x_hat": x_hat, + } + return result + + def decompress(self, bit_stream, sps, qp): + dtype = next(self.parameters()).dtype + device = next(self.parameters()).device + curr_q_dec = 
self.q_scale_dec[qp:qp+1, :, :, :] + + self.entropy_coder.set_use_two_entropy_coders(sps['ec_part'] == 1) + self.entropy_coder.set_stream(bit_stream) + z_size = self.get_downsampled_shape(sps['height'], sps['width'], 64) + y_height, y_width = self.get_downsampled_shape(sps['height'], sps['width'], 16) + self.bit_estimator_z.decode_z(z_size, qp) + z_q = self.bit_estimator_z.get_z(z_size, device, dtype) + z_hat = z_q + + params = self.hyper_dec(z_hat) + params = self.y_prior_fusion(params) + params = params[:, :, :y_height, :y_width].contiguous() + y_hat = self.decompress_prior_4x(params, self.y_spatial_prior_reduction, + self.y_spatial_prior_adaptor_1, + self.y_spatial_prior_adaptor_2, + self.y_spatial_prior_adaptor_3, self.y_spatial_prior) + + x_hat = self.dec(y_hat, curr_q_dec).clamp_(0, 1) + return {"x_hat": x_hat} diff --git a/DCVC-RT/src/models/video_model.py b/DCVC-RT/src/models/video_model.py new file mode 100644 index 0000000..a770114 --- /dev/null +++ b/DCVC-RT/src/models/video_model.py @@ -0,0 +1,379 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import torch +import torch.nn.functional as F +from torch import nn + +from .common_model import CompressionModel +from ..layers.layers import SubpelConv2x, DepthConvBlock, \ + ResidualBlockUpsample, ResidualBlockWithStride2 +from ..layers.cuda_inference import CUSTOMIZED_CUDA_INFERENCE, round_and_to_int8, \ + bias_pixel_shuffle_8, bias_quant + + +qp_shift = [0, 8, 4] +extra_qp = max(qp_shift) + +g_ch_src_d = 3 * 8 * 8 +g_ch_recon = 320 +g_ch_y = 128 +g_ch_z = 128 +g_ch_d = 256 + + +class FeatureExtractor(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Sequential( + DepthConvBlock(g_ch_d, g_ch_d), + DepthConvBlock(g_ch_d, g_ch_d), + ) + self.conv2 = nn.Sequential( + DepthConvBlock(g_ch_d, g_ch_d), + DepthConvBlock(g_ch_d, g_ch_d), + DepthConvBlock(g_ch_d, g_ch_d), + DepthConvBlock(g_ch_d, g_ch_d), + ) + + def forward(self, x, quant): + x1, ctx_t = self.forward_part1(x, quant) + ctx = self.forward_part2(x1) + return ctx, ctx_t + + def forward_part1(self, x, quant): + x1 = self.conv1(x) + ctx_t = x1 * quant + return x1, ctx_t + + def forward_part2(self, x1): + ctx = self.conv2(x1) + return ctx + + +class Encoder(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(g_ch_src_d, g_ch_d, 1) + self.conv2 = nn.Sequential( + DepthConvBlock(g_ch_d * 2, g_ch_d), + DepthConvBlock(g_ch_d, g_ch_d), + ) + self.conv3 = DepthConvBlock(g_ch_d, g_ch_d) + self.down = nn.Conv2d(g_ch_d, g_ch_y, 3, stride=2, padding=1) + + self.fuse_conv1_flag = False + + def forward(self, x, ctx, quant_step): + feature = F.pixel_unshuffle(x, 8) + if not CUSTOMIZED_CUDA_INFERENCE or not x.is_cuda: + return self.forward_torch(feature, ctx, quant_step) + return self.forward_cuda(feature, ctx, quant_step) + + def forward_torch(self, feature, ctx, quant_step): + feature = self.conv1(feature) + feature = self.conv2(torch.cat((feature, ctx), dim=1)) + feature = self.conv3(feature) + feature = feature * quant_step + feature = self.down(feature) + return 
feature + + def forward_cuda(self, feature, ctx, quant_step): + if not self.fuse_conv1_flag: + fuse_weight1 = torch.matmul( + self.conv2[0].adaptor.weight[:, :g_ch_d, 0, 0], + self.conv1.weight[:, :, 0, 0] + )[:, :, None, None] + fuse_weight2 = self.conv2[0].adaptor.weight[:, g_ch_d:] + self.conv2[0].adaptor.bias.data = self.conv2[0].adaptor.bias + \ + torch.matmul(self.conv2[0].adaptor.weight[:, :g_ch_d, 0, 0], + self.conv1.bias[:, None])[:, 0] + self.conv2[0].adaptor.weight.data = torch.cat([fuse_weight1, fuse_weight2], dim=1) + self.fuse_conv1_flag = True + + feature = self.conv2(torch.cat((feature, ctx), dim=1)) + feature = self.conv3(feature, quant_step=quant_step) + feature = self.down(feature) + return feature + + +class Decoder(nn.Module): + def __init__(self): + super().__init__() + self.up = SubpelConv2x(g_ch_y, g_ch_d, 3, padding=1) + self.conv1 = nn.Sequential( + DepthConvBlock(g_ch_d * 2, g_ch_d), + DepthConvBlock(g_ch_d, g_ch_d), + DepthConvBlock(g_ch_d, g_ch_d), + ) + self.conv2 = nn.Conv2d(g_ch_d, g_ch_d, 1) + + def forward(self, x, ctx, quant_step,): + if not CUSTOMIZED_CUDA_INFERENCE or not x.is_cuda: + return self.forward_torch(x, ctx, quant_step) + + return self.forward_cuda(x, ctx, quant_step) + + def forward_torch(self, x, ctx, quant_step): + feature = self.up(x) + feature = self.conv1(torch.cat((feature, ctx), dim=1)) + feature = self.conv2(feature) + feature = feature * quant_step + return feature + + def forward_cuda(self, x, ctx, quant_step): + feature = self.up(x, to_cat=ctx, cat_at_front=False) + feature = self.conv1(feature) + feature = F.conv2d(feature, self.conv2.weight) + feature = bias_quant(feature, self.conv2.bias, quant_step) + return feature + + +class ReconGeneration(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Sequential( + DepthConvBlock(g_ch_d, g_ch_recon), + DepthConvBlock(g_ch_recon, g_ch_recon), + DepthConvBlock(g_ch_recon, g_ch_recon), + DepthConvBlock(g_ch_recon, g_ch_recon), + ) + self.head = 
nn.Conv2d(g_ch_recon, g_ch_src_d, 1) + + def forward(self, x, quant_step): + if not CUSTOMIZED_CUDA_INFERENCE or not x.is_cuda: + return self.forward_torch(x, quant_step) + + return self.forward_cuda(x, quant_step) + + def forward_torch(self, x, quant_step): + out = self.conv(x) + out = out * quant_step + out = self.head(out) + out = F.pixel_shuffle(out, 8) + out = torch.clamp(out, 0., 1.) + return out + + def forward_cuda(self, x, quant_step): + out = self.conv[0](x) + out = self.conv[1](out) + out = self.conv[2](out) + out = self.conv[3](out, quant_step=quant_step) + out = F.conv2d(out, self.head.weight) + return bias_pixel_shuffle_8(out, self.head.bias) + + +class HyperEncoder(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Sequential( + DepthConvBlock(g_ch_y, g_ch_z), + ResidualBlockWithStride2(g_ch_z, g_ch_z), + ResidualBlockWithStride2(g_ch_z, g_ch_z), + ) + + def forward(self, x): + return self.conv(x) + + +class HyperDecoder(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Sequential( + ResidualBlockUpsample(g_ch_z, g_ch_z), + ResidualBlockUpsample(g_ch_z, g_ch_z), + DepthConvBlock(g_ch_z, g_ch_y), + ) + + def forward(self, x): + return self.conv(x) + + +class PriorFusion(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Sequential( + DepthConvBlock(g_ch_y * 3, g_ch_y * 3), + DepthConvBlock(g_ch_y * 3, g_ch_y * 3), + DepthConvBlock(g_ch_y * 3, g_ch_y * 3), + nn.Conv2d(g_ch_y * 3, g_ch_y * 3, 1), + ) + + def forward(self, x): + return self.conv(x) + + +class SpatialPrior(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Sequential( + DepthConvBlock(g_ch_y * 4, g_ch_y * 3), + DepthConvBlock(g_ch_y * 3, g_ch_y * 3), + nn.Conv2d(g_ch_y * 3, g_ch_y * 2, 1), + ) + + def forward(self, x): + return self.conv(x) + + +class RefFrame(): + def __init__(self): + self.frame = None + self.feature = None + self.poc = None + + +class DMC(CompressionModel): + def __init__(self): + 
super().__init__(z_channel=g_ch_z, extra_qp=extra_qp) + self.qp_shift = qp_shift + + self.feature_adaptor_i = DepthConvBlock(g_ch_src_d, g_ch_d) + self.feature_adaptor_p = nn.Conv2d(g_ch_d, g_ch_d, 1) + self.feature_extractor = FeatureExtractor() + + self.encoder = Encoder() + self.hyper_encoder = HyperEncoder() + self.hyper_decoder = HyperDecoder() + self.temporal_prior_encoder = ResidualBlockWithStride2(g_ch_d, g_ch_y * 2) + self.y_prior_fusion = PriorFusion() + self.y_spatial_prior = SpatialPrior() + self.decoder = Decoder() + self.recon_generation_net = ReconGeneration() + + self.q_encoder = nn.Parameter(torch.ones((self.get_qp_num() + extra_qp, g_ch_d, 1, 1))) + self.q_decoder = nn.Parameter(torch.ones((self.get_qp_num() + extra_qp, g_ch_d, 1, 1))) + self.q_feature = nn.Parameter(torch.ones((self.get_qp_num() + extra_qp, g_ch_d, 1, 1))) + self.q_recon = nn.Parameter(torch.ones((self.get_qp_num() + extra_qp, g_ch_recon, 1, 1))) + + self.dpb = [] + self.max_dpb_size = 1 + self.curr_poc = 0 + + def reset_ref_feature(self): + if len(self.dpb) > 0: + self.dpb[0].feature = None + + def add_ref_frame(self, feature=None, frame=None, increase_poc=True): + ref_frame = RefFrame() + ref_frame.poc = self.curr_poc + ref_frame.frame = frame + ref_frame.feature = feature + if len(self.dpb) >= self.max_dpb_size: + self.dpb.pop(-1) + self.dpb.insert(0, ref_frame) + if increase_poc: + self.curr_poc += 1 + + def clear_dpb(self): + self.dpb.clear() + + def set_curr_poc(self, poc): + self.curr_poc = poc + + def apply_feature_adaptor(self): + if self.dpb[0].feature is None: + return self.feature_adaptor_i(F.pixel_unshuffle(self.dpb[0].frame, 8)) + return self.feature_adaptor_p(self.dpb[0].feature) + + def res_prior_param_decoder(self, z_hat, ctx_t): + hierarchical_params = self.hyper_decoder(z_hat) + temporal_params = self.temporal_prior_encoder(ctx_t) + _, _, H, W = temporal_params.shape + hierarchical_params = hierarchical_params[:, :, :H, :W].contiguous() + params = 
self.y_prior_fusion( + torch.cat((hierarchical_params, temporal_params), dim=1)) + return params + + def get_recon_and_feature(self, y_hat, ctx, q_decoder, q_recon): + feature = self.decoder(y_hat, ctx, q_decoder) + x_hat = self.recon_generation_net(feature, q_recon) + return x_hat, feature + + def prepare_feature_adaptor_i(self, last_qp): + if self.dpb[0].frame is None: + q_recon = self.q_recon[last_qp:last_qp+1, :, :, :] + self.dpb[0].frame = self.recon_generation_net(self.dpb[0].feature, q_recon).clamp_(0, 1) + self.reset_ref_feature() + + def compress(self, x, qp): + # pic_width and pic_height may be different from x's size. x here is after padding + # x_hat has the same size with x + device = x.device + q_encoder = self.q_encoder[qp:qp+1, :, :, :] + q_decoder = self.q_decoder[qp:qp+1, :, :, :] + q_feature = self.q_feature[qp:qp+1, :, :, :] + + feature = self.apply_feature_adaptor() + ctx, ctx_t = self.feature_extractor(feature, q_feature) + y = self.encoder(x, ctx, q_encoder) + + hyper_inp = self.pad_for_y(y) + + z = self.hyper_encoder(hyper_inp) + z_hat, z_hat_write = round_and_to_int8(z) + cuda_event_z_ready = torch.cuda.Event() + cuda_event_z_ready.record() + params = self.res_prior_param_decoder(z_hat, ctx_t) + y_q_w_0, y_q_w_1, s_w_0, s_w_1, y_hat = \ + self.compress_prior_2x(y, params, self.y_spatial_prior) + + cuda_event_y_ready = torch.cuda.Event() + cuda_event_y_ready.record() + feature = self.decoder(y_hat, ctx, q_decoder) + + cuda_stream = self.get_cuda_stream(device=device, priority=-1) + with torch.cuda.stream(cuda_stream): + self.entropy_coder.reset() + cuda_event_z_ready.wait() + self.bit_estimator_z.encode_z(z_hat_write, qp) + cuda_event_y_ready.wait() + self.gaussian_encoder.encode_y(y_q_w_0, s_w_0) + self.gaussian_encoder.encode_y(y_q_w_1, s_w_1) + self.entropy_coder.flush() + + bit_stream = self.entropy_coder.get_encoded_stream() + + torch.cuda.synchronize(device=device) + self.add_ref_frame(feature, None) + return { + 'bit_stream': 
bit_stream, + } + + def decompress(self, bit_stream, sps, qp): + dtype = next(self.parameters()).dtype + device = next(self.parameters()).device + q_decoder = self.q_decoder[qp:qp+1, :, :, :] + q_feature = self.q_feature[qp:qp+1, :, :, :] + q_recon = self.q_recon[qp:qp+1, :, :, :] + + self.entropy_coder.set_use_two_entropy_coders(sps['ec_part'] == 1) + self.entropy_coder.set_stream(bit_stream) + z_size = self.get_downsampled_shape(sps['height'], sps['width'], 64) + self.bit_estimator_z.decode_z(z_size, qp) + + feature = self.apply_feature_adaptor() + c1, ctx_t = self.feature_extractor.forward_part1(feature, q_feature) + + z_hat = self.bit_estimator_z.get_z(z_size, device, dtype) + params = self.res_prior_param_decoder(z_hat, ctx_t) + infos = self.decompress_prior_2x_part1(params) + + ctx = self.feature_extractor.forward_part2(c1) + + cuda_stream = self.get_cuda_stream(device=device, priority=-1) + with torch.cuda.stream(cuda_stream): + y_hat = self.decompress_prior_2x_part2(params, self.y_spatial_prior, infos) + cuda_event = torch.cuda.Event() + cuda_event.record() + + cuda_event.wait() + x_hat, feature = self.get_recon_and_feature(y_hat, ctx, q_decoder, q_recon) + + self.add_ref_frame(feature, x_hat) + return { + 'x_hat': x_hat, + } + + def shift_qp(self, qp, fa_idx): + return qp + self.qp_shift[fa_idx] diff --git a/DCVC-RT/src/utils/common.py b/DCVC-RT/src/utils/common.py new file mode 100644 index 0000000..a7599c0 --- /dev/null +++ b/DCVC-RT/src/utils/common.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import json +import os +from unittest.mock import patch + +import torch +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present +import numpy as np + + +def str2bool(v): + return str(v).lower() in ("yes", "y", "true", "t", "1") + + +def set_torch_env(): + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + torch.manual_seed(0) + torch.set_num_threads(1) + np.random.seed(seed=0) + try: + # requires pytorch >= 2.2.0 + torch.utils.deterministic.fill_uninitialized_memory = False + except Exception: # pylint: disable=W0718 + pass + + +def create_folder(path, print_if_create=False): + if not os.path.exists(path): + os.makedirs(path) + if print_if_create: + print(f"created folder: {path}") + + +def get_state_dict(ckpt_path): + ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'), weights_only=True) + if "state_dict" in ckpt: + ckpt = ckpt['state_dict'] + if "net" in ckpt: + ckpt = ckpt["net"] + consume_prefix_in_state_dict_if_present(ckpt, prefix="module.") + return ckpt + + +@patch('json.encoder.c_make_encoder', None) +def dump_json(obj, fid, float_digits=-1, **kwargs): + of = json.encoder._make_iterencode # pylint: disable=W0212 + + def inner(*args, **kwargs): + args = list(args) + # the fifth positional argument is the float formatter, which we replace + args[4] = lambda o: format(o, '.%df' % float_digits) + return of(*args, **kwargs) + + with patch('json.encoder._make_iterencode', wraps=inner): + json.dump(obj, fid, **kwargs) + + +def generate_log_json(frame_num, frame_pixel_num, test_time, frame_types, bits, psnrs, ssims, + verbose=False, avg_encoding_time=None, avg_decoding_time=None): + include_yuv = len(psnrs[0]) > 1 + assert not include_yuv or (len(psnrs[0]) == 4 and len(ssims[0]) == 4) + i_bits = 0 + i_psnr = 0 + i_psnr_y = 0 + i_psnr_u = 0 + i_psnr_v = 0 + i_ssim = 0 + i_ssim_y = 0 + i_ssim_u = 0 + i_ssim_v = 0 + p_bits = 0 +
p_psnr = 0 + p_psnr_y = 0 + p_psnr_u = 0 + p_psnr_v = 0 + p_ssim = 0 + p_ssim_y = 0 + p_ssim_u = 0 + p_ssim_v = 0 + i_num = 0 + p_num = 0 + for idx in range(frame_num): + if frame_types[idx] == 0: + i_bits += bits[idx] + i_psnr += psnrs[idx][0] + i_ssim += ssims[idx][0] + i_num += 1 + if include_yuv: + i_psnr_y += psnrs[idx][1] + i_psnr_u += psnrs[idx][2] + i_psnr_v += psnrs[idx][3] + i_ssim_y += ssims[idx][1] + i_ssim_u += ssims[idx][2] + i_ssim_v += ssims[idx][3] + else: + p_bits += bits[idx] + p_psnr += psnrs[idx][0] + p_ssim += ssims[idx][0] + p_num += 1 + if include_yuv: + p_psnr_y += psnrs[idx][1] + p_psnr_u += psnrs[idx][2] + p_psnr_v += psnrs[idx][3] + p_ssim_y += ssims[idx][1] + p_ssim_u += ssims[idx][2] + p_ssim_v += ssims[idx][3] + + log_result = {} + log_result['frame_pixel_num'] = frame_pixel_num + log_result['i_frame_num'] = i_num + log_result['p_frame_num'] = p_num + log_result['ave_i_frame_bpp'] = i_bits / i_num / frame_pixel_num + log_result['ave_i_frame_psnr'] = i_psnr / i_num + log_result['ave_i_frame_msssim'] = i_ssim / i_num + if include_yuv: + log_result['ave_i_frame_psnr_y'] = i_psnr_y / i_num + log_result['ave_i_frame_psnr_u'] = i_psnr_u / i_num + log_result['ave_i_frame_psnr_v'] = i_psnr_v / i_num + log_result['ave_i_frame_msssim_y'] = i_ssim_y / i_num + log_result['ave_i_frame_msssim_u'] = i_ssim_u / i_num + log_result['ave_i_frame_msssim_v'] = i_ssim_v / i_num + if verbose: + log_result['frame_bpp'] = list(np.array(bits) / frame_pixel_num) + log_result['frame_psnr'] = [v[0] for v in psnrs] + log_result['frame_msssim'] = [v[0] for v in ssims] + log_result['frame_type'] = frame_types + if include_yuv: + log_result['frame_psnr_y'] = [v[1] for v in psnrs] + log_result['frame_psnr_u'] = [v[2] for v in psnrs] + log_result['frame_psnr_v'] = [v[3] for v in psnrs] + log_result['frame_msssim_y'] = [v[1] for v in ssims] + log_result['frame_msssim_u'] = [v[2] for v in ssims] + log_result['frame_msssim_v'] = [v[3] for v in ssims] + 
log_result['test_time'] = test_time + if p_num > 0: + total_p_pixel_num = p_num * frame_pixel_num + log_result['ave_p_frame_bpp'] = p_bits / total_p_pixel_num + log_result['ave_p_frame_psnr'] = p_psnr / p_num + log_result['ave_p_frame_msssim'] = p_ssim / p_num + if include_yuv: + log_result['ave_p_frame_psnr_y'] = p_psnr_y / p_num + log_result['ave_p_frame_psnr_u'] = p_psnr_u / p_num + log_result['ave_p_frame_psnr_v'] = p_psnr_v / p_num + log_result['ave_p_frame_msssim_y'] = p_ssim_y / p_num + log_result['ave_p_frame_msssim_u'] = p_ssim_u / p_num + log_result['ave_p_frame_msssim_v'] = p_ssim_v / p_num + else: + log_result['ave_p_frame_bpp'] = 0 + log_result['ave_p_frame_psnr'] = 0 + log_result['ave_p_frame_msssim'] = 0 + if include_yuv: + log_result['ave_p_frame_psnr_y'] = 0 + log_result['ave_p_frame_psnr_u'] = 0 + log_result['ave_p_frame_psnr_v'] = 0 + log_result['ave_p_frame_msssim_y'] = 0 + log_result['ave_p_frame_msssim_u'] = 0 + log_result['ave_p_frame_msssim_v'] = 0 + log_result['ave_all_frame_bpp'] = (i_bits + p_bits) / (frame_num * frame_pixel_num) + log_result['ave_all_frame_psnr'] = (i_psnr + p_psnr) / frame_num + log_result['ave_all_frame_msssim'] = (i_ssim + p_ssim) / frame_num + if avg_encoding_time is not None and avg_decoding_time is not None: + log_result['avg_frame_encoding_time'] = avg_encoding_time + log_result['avg_frame_decoding_time'] = avg_decoding_time + if include_yuv: + log_result['ave_all_frame_psnr_y'] = (i_psnr_y + p_psnr_y) / frame_num + log_result['ave_all_frame_psnr_u'] = (i_psnr_u + p_psnr_u) / frame_num + log_result['ave_all_frame_psnr_v'] = (i_psnr_v + p_psnr_v) / frame_num + log_result['ave_all_frame_msssim_y'] = (i_ssim_y + p_ssim_y) / frame_num + log_result['ave_all_frame_msssim_u'] = (i_ssim_u + p_ssim_u) / frame_num + log_result['ave_all_frame_msssim_v'] = (i_ssim_v + p_ssim_v) / frame_num + + return log_result diff --git a/DCVC-RT/src/utils/metrics.py b/DCVC-RT/src/utils/metrics.py new file mode 100644 index 0000000..869a9ab 
--- /dev/null
+++ b/DCVC-RT/src/utils/metrics.py
@@ -0,0 +1,96 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+from scipy import signal
+from scipy import ndimage
+
+
+def fspecial_gauss(size, sigma):
+    x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
+    g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
+    return g / g.sum()
+
+
+def calc_ssim(img1, img2, data_range=255):
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    size = 11
+    sigma = 1.5
+    window = fspecial_gauss(size, sigma)
+    K1 = 0.01
+    K2 = 0.03
+    C1 = (K1 * data_range)**2
+    C2 = (K2 * data_range)**2
+    mu1 = signal.fftconvolve(window, img1, mode='valid')
+    mu2 = signal.fftconvolve(window, img2, mode='valid')
+    mu1_sq = mu1 * mu1
+    mu2_sq = mu2 * mu2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = signal.fftconvolve(window, img1 * img1, mode='valid') - mu1_sq
+    sigma2_sq = signal.fftconvolve(window, img2 * img2, mode='valid') - mu2_sq
+    sigma12 = signal.fftconvolve(window, img1 * img2, mode='valid') - mu1_mu2
+
+    return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                         (sigma1_sq + sigma2_sq + C2)),
+            (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
+
+
+def calc_msssim(img1, img2, data_range=255):
+    '''
+    img1 and img2 are 2D arrays
+    '''
+    level = 5
+    weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
+    height, width = img1.shape
+    if height < 176 or width < 176:
+        # according to HM implementation
+        level = 4
+        weight = np.array([0.0517, 0.3295, 0.3462, 0.2726])
+    if height < 88 or width < 88:
+        raise ValueError('image is too small for MS-SSIM (minimum 88x88)')
+    downsample_filter = np.ones((2, 2)) / 4.0
+    im1 = img1.astype(np.float64)
+    im2 = img2.astype(np.float64)
+    mssim = np.array([])
+    mcs = np.array([])
+    for _ in range(level):
+        ssim_map, cs_map = calc_ssim(im1, im2, data_range=data_range)
+        mssim = np.append(mssim, ssim_map.mean())
+        mcs = np.append(mcs, cs_map.mean())
+        # the scipy.ndimage.filters namespace is deprecated; call ndimage.convolve directly
+        filtered_im1 = ndimage.convolve(im1, downsample_filter, mode='reflect')
+        filtered_im2 = ndimage.convolve(im2, downsample_filter, mode='reflect')
+        im1 = filtered_im1[::2, ::2]
+        im2 = filtered_im2[::2, ::2]
+    return (np.prod(mcs[0:level - 1]**weight[0:level - 1]) *
+            (mssim[level - 1]**weight[level - 1]))
+
+
+def calc_msssim_rgb(img1, img2, data_range=255):
+    '''
+    img1 and img2 are arrays with shape 3xHxW
+    '''
+    msssim = 0
+    for i in range(3):
+        msssim += calc_msssim(img1[i, :, :], img2[i, :, :], data_range)
+    return msssim / 3
+
+
+def calc_psnr(img1, img2, data_range=255):
+    '''
+    img1 and img2 are arrays with the same shape
+    '''
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    mse = np.mean(np.square(img1 - img2))
+    if np.isnan(mse) or np.isinf(mse):
+        return -999.9
+    if mse > 1e-10:
+        psnr = 10 * np.log10(data_range * data_range / mse)
+    else:
+        psnr = 999.9
+    if psnr > 99.9:
+        psnr = 99.9
+    return psnr
diff --git a/DCVC-RT/src/utils/stream_helper.py b/DCVC-RT/src/utils/stream_helper.py
new file mode 100644
index 0000000..1450782
--- /dev/null
+++ b/DCVC-RT/src/utils/stream_helper.py
@@ -0,0 +1,217 @@
+# Copyright 2020 InterDigital Communications, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+import struct
+from pathlib import Path
+
+
+def filesize(filepath: str) -> int:
+    if not Path(filepath).is_file():
+        raise ValueError(f'Invalid file "{filepath}".')
+    return Path(filepath).stat().st_size
+
+
+def write_uints(fd, values, fmt=">{:d}I"):
+    fd.write(struct.pack(fmt.format(len(values)), *values))
+    return len(values) * 4
+
+
+def write_uchars(fd, values, fmt=">{:d}B"):
+    fd.write(struct.pack(fmt.format(len(values)), *values))
+    return len(values)
+
+
+def read_uints(fd, n, fmt=">{:d}I"):
+    sz = struct.calcsize("I")
+    return struct.unpack(fmt.format(n), fd.read(n * sz))
+
+
+def read_uchars(fd, n, fmt=">{:d}B"):
+    sz = struct.calcsize("B")
+    return struct.unpack(fmt.format(n), fd.read(n * sz))
+
+
+def write_bytes(fd, values, fmt=">{:d}s"):
+    if len(values) == 0:
+        return 0
+    fd.write(struct.pack(fmt.format(len(values)), values))
+    return len(values)
+
+
+def read_bytes(fd, n, fmt=">{:d}s"):
+    sz = struct.calcsize("s")
+    return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
+
+
+def write_ushorts(fd, values, fmt=">{:d}H"):
+    fd.write(struct.pack(fmt.format(len(values)), *values))
+    return len(values) * 2
+
+
+def read_ushorts(fd, n, fmt=">{:d}H"):
+    sz = struct.calcsize("H")
+    return struct.unpack(fmt.format(n), fd.read(n * sz))
+
+
+def write_uint_adaptive(f, a):
+    if a < (1 << 7):
+        a0 = (a >> 0) & 0xff
+        a0 = a0 | (0x00 << 7)
+        write_uchars(f, (a0,))
+        return 1
+
+    if a < (1 << 14):
+        a0 = (a >> 0) & 0xff
+        a1 = (a >> 8) & 0xff
+        a1 = a1 | (0x02 << 6)
+        write_uchars(f, (a1, a0))
+        return 2
+
+    assert a < (1 << 30)
+    a0 = (a >> 0) & 0xff
+    a1 = (a >> 8) & 0xff
+    a2 = (a >> 16) & 0xff
+    a3 = (a >> 24) & 0xff
+    a3 = a3 | (0x03 << 6)
+    write_uchars(f, (a3, a2, a1, a0))
+    return 4
+
+
+def read_uint_adaptive(f):
+    a3 = read_uchars(f, 1)[0]
+    if (a3 >> 7) == 0:
+        return a3
+
+    a2 = read_uchars(f, 1)[0]
+
+    if (a3 >> 6) == 0x02:
+        a3 = a3 & 0x3f
+        return (a3 << 8) + a2
+    a3 = a3 & 0x3f
+    a1 = read_uchars(f, 1)[0]
+    a0 = read_uchars(f, 1)[0]
+    return (a3 << 24) + (a2 << 16) + (a1 << 8) + a0
+
+
+class NalType(enum.IntEnum):
+    NAL_SPS = 0
+    NAL_I = 1
+    NAL_P = 2
+
+
+class SPSHelper():
+    def __init__(self):
+        super().__init__()
+        self.spss = []
+
+    def get_sps_id(self, target_sps):
+        min_id = -1
+        for sps in self.spss:
+            if sps['height'] == target_sps['height'] and sps['width'] == target_sps['width'] and \
+                    sps['use_ada_i'] == target_sps['use_ada_i'] and \
+                    sps['ec_part'] == target_sps['ec_part']:
+                return sps['sps_id'], False
+            if sps['sps_id'] > min_id:
+                min_id = sps['sps_id']
+        assert min_id < 15
+        sps = target_sps.copy()
+        sps['sps_id'] = min_id + 1
+        self.spss.append(sps)
+        return sps['sps_id'], True
+
+    def add_sps_by_id(self, sps):
+        for i in range(len(self.spss)):
+            if self.spss[i]['sps_id'] == sps['sps_id']:
+                self.spss[i] = sps.copy()
+                return
+        self.spss.append(sps.copy())
+
+    def get_sps_by_id(self, sps_id):
+        for sps in self.spss:
+            if sps['sps_id'] == sps_id:
+                return sps
+        return None
+
+
+def write_sps(f, sps):
+    # nal_type(4), sps_id(4)
+    # height (variable)
+    # width (variable)
+    # 0(5), ec_part(1), 0(1), use_ada_i(1)
+    assert sps['sps_id'] < 16
+    assert sps['use_ada_i'] < 2
+    written = 0
+    flag = int((NalType.NAL_SPS << 4) + sps['sps_id'])
+    written += write_uchars(f, (flag,))
+    written += write_uint_adaptive(f, sps['height'])
+    written += write_uint_adaptive(f, sps['width'])
+    flag = (sps['ec_part'] << 2) + sps['use_ada_i']
+    written += write_uchars(f, (flag,))
+    return written
+
+
+def read_header(f):
+    header = {}
+    flag = read_uchars(f, 1)[0]
+    nal_type = flag >> 4
+    header['nal_type'] = NalType(nal_type)
+    if nal_type < 3:
+        header['sps_id'] = flag & 0x0f
+        return header
+
+    frame_num_minus1 = flag & 0x0f
+    frame_num = frame_num_minus1 + 1
+    header['frame_num'] = frame_num
+    sps_ids = []
+    for _ in range(0, frame_num, 2):
+        flag = read_uchars(f, 1)[0]
+        sps_ids.append(flag >> 4)
+        sps_ids.append(flag & 0x0f)
+    sps_ids = sps_ids[:frame_num]
+    header['sps_ids'] = sps_ids
+    return header
+
+
+def read_sps_remaining(f, sps_id):
+    sps = {}
+    sps['sps_id'] = sps_id
+    sps['height'] = read_uint_adaptive(f)
+    sps['width'] = read_uint_adaptive(f)
+    flag = read_uchars(f, 1)[0]
+    sps['ec_part'] = (flag >> 2) & 0x01
+    sps['use_ada_i'] = flag & 0x01
+    return sps
+
+
+def write_ip(f, is_i_frame, sps_id, qp, bit_stream):
+    written = 0
+    flag = (int(NalType.NAL_I if is_i_frame else NalType.NAL_P) << 4) + sps_id
+    written += write_uchars(f, (flag,))
+    assert 0 <= qp < 256
+    flag = qp
+    written += write_uchars(f, (flag,))
+    # we write all the streams in the same file, thus, we need to write the per-frame length
+    # if packed independently, we do not need to write it
+    written += write_uint_adaptive(f, len(bit_stream))
+    written += write_bytes(f, bit_stream)
+    return written
+
+
+def read_ip_remaining(f):
+    flag = read_uchars(f, 1)[0]
+    qp = flag
+    stream_length = read_uint_adaptive(f)
+    bit_stream = read_bytes(f, stream_length)
+    return qp, bit_stream
diff --git a/DCVC-RT/src/utils/transforms.py b/DCVC-RT/src/utils/transforms.py
new file mode 100644
index 0000000..9b96d17
--- /dev/null
+++ b/DCVC-RT/src/utils/transforms.py
@@ -0,0 +1,63 @@
+import numpy as np
+import scipy.ndimage
+import torch
+import torch.nn.functional as F
+
+
+YCBCR_WEIGHTS = {
+    # Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
+    "ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
+}
+
+
+def ycbcr420_to_444_np(y, uv, order=0, separate=False):
+    '''
+    y is a 1xhxw Y float numpy array
+    uv is a 2x(h/2)x(w/2) UV float numpy array
+    order: 0 nearest neighbor (default), 1: bilinear
+    return value is a 3xhxw YCbCr float numpy array
+    '''
+    uv = scipy.ndimage.zoom(uv, (1, 2, 2), order=order)
+    if separate:
+        return y, uv
+    yuv = np.concatenate((y, uv), axis=0)
+    return yuv
+
+
+def rgb2ycbcr(rgb, is_bgr=False):
+    if is_bgr:
+        b, g, r = rgb.chunk(3, -3)
+    else:
+        r, g, b = rgb.chunk(3, -3)
+    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+    y = Kr * r + Kg * g + Kb * b
+    cb = 0.5 * (b - y) / (1 - Kb) + 0.5
+    cr = 0.5 * (r - y) / (1 - Kr) + 0.5
+    ycbcr = torch.cat((y, cb, cr), dim=-3)
+    ycbcr = torch.clamp(ycbcr, 0., 1.)
+    return ycbcr
+
+
+def ycbcr2rgb(ycbcr, is_bgr=False, clamp=True):
+    y, cb, cr = ycbcr.chunk(3, -3)
+    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+    r = y + (2 - 2 * Kr) * (cr - 0.5)
+    b = y + (2 - 2 * Kb) * (cb - 0.5)
+    g = (y - Kr * r - Kb * b) / Kg
+    if is_bgr:
+        rgb = torch.cat((b, g, r), dim=-3)
+    else:
+        rgb = torch.cat((r, g, b), dim=-3)
+    if clamp:
+        rgb = torch.clamp(rgb, 0., 1.)
+    return rgb
+
+
+def yuv_444_to_420(yuv):
+    def _downsample(tensor):
+        return F.avg_pool2d(tensor, kernel_size=2, stride=2)
+
+    y = yuv[:, :1, :, :]
+    uv = yuv[:, 1:, :, :]
+
+    return y, _downsample(uv)
diff --git a/DCVC-RT/src/utils/video_reader.py b/DCVC-RT/src/utils/video_reader.py
new file mode 100644
index 0000000..a251845
--- /dev/null
+++ b/DCVC-RT/src/utils/video_reader.py
@@ -0,0 +1,90 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+from PIL import Image
+
+
+class PNGReader():
+    def __init__(self, src_path, width, height, start_num=1):
+        self.eof = False
+        self.src_path = src_path
+        self.width = width
+        self.height = height
+        pngs = os.listdir(self.src_path)
+        if 'im1.png' in pngs:
+            self.padding = 1
+        elif 'im00001.png' in pngs:
+            self.padding = 5
+        else:
+            raise ValueError('unknown image naming convention; please specify')
+        self.current_frame_index = start_num
+
+    def read_one_frame(self):
+        # rgb: 3xhxw uint8 numpy array
+        if self.eof:
+            return None
+
+        png_path = os.path.join(self.src_path,
+                                f"im{str(self.current_frame_index).zfill(self.padding)}.png")
+        if not os.path.exists(png_path):
+            self.eof = True
+            return None
+
+        rgb = Image.open(png_path).convert('RGB')
+        rgb = np.asarray(rgb).astype(np.uint8).transpose(2, 0, 1)
+        _, height, width = rgb.shape
+        assert height == self.height
+        assert width == self.width
+
+        self.current_frame_index += 1
+        return rgb
+
+    def close(self):
+        self.current_frame_index = 1
+
+
+class YUV420Reader():
+    def __init__(self, src_path, width, height, skip_frame=0):
+        self.eof = False
+        if not src_path.endswith('.yuv'):
+            src_path = src_path + '.yuv'
+        self.src_path = src_path
+
+        self.y_size = width * height
+        self.y_width = width
+        self.y_height = height
+        self.uv_size = width * height // 2
+        self.uv_width = width // 2
+        self.uv_height = height // 2
+        # pylint: disable=R1732
+        self.file = open(src_path, "rb")
+        # pylint: enable=R1732
+        skipped_frame = 0
+        while not self.eof and skipped_frame < skip_frame:
+            y = self.file.read(self.y_size)
+            uv = self.file.read(self.uv_size)
+            if not y or not uv:
+                self.eof = True
+            skipped_frame += 1
+
+    def read_one_frame(self):
+        # y: 1xhxw uint8 numpy array
+        # uv: 2x(h/2)x(w/2) uint8 numpy array
+        if self.eof:
+            return None, None
+        y = self.file.read(self.y_size)
+        uv = self.file.read(self.uv_size)
+        if not y or not uv:
+            self.eof = True
+            return None, None
+        y = np.frombuffer(y, dtype=np.uint8).copy().reshape(1, self.y_height, self.y_width)
+        uv = np.frombuffer(uv, dtype=np.uint8).copy().reshape(2, self.uv_height, self.uv_width)
+
+        return y, uv
+
+    def close(self):
+        self.file.close()
diff --git a/DCVC-RT/src/utils/video_writer.py b/DCVC-RT/src/utils/video_writer.py
new file mode 100644
index 0000000..d61089d
--- /dev/null
+++ b/DCVC-RT/src/utils/video_writer.py
@@ -0,0 +1,52 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+
+from PIL import Image
+
+
+class PNGWriter():
+    def __init__(self, dst_path, width, height):
+        self.dst_path = dst_path
+        self.width = width
+        self.height = height
+        self.padding = 5
+        self.current_frame_index = 1
+        os.makedirs(dst_path, exist_ok=True)
+
+    def write_one_frame(self, rgb):
+        # rgb: 3xhxw uint8 numpy array
+        rgb = rgb.transpose(1, 2, 0)
+
+        png_path = os.path.join(self.dst_path,
+                                f"im{str(self.current_frame_index).zfill(self.padding)}.png")
+        Image.fromarray(rgb).save(png_path)
+
+        self.current_frame_index += 1
+
+    def close(self):
+        self.current_frame_index = 1
+
+
+class YUV420Writer():
+    def __init__(self, dst_path, width, height):
+        if not dst_path.endswith('.yuv'):
+            dst_path = dst_path + '/out.yuv'
+        self.dst_path = dst_path
+        self.width = width
+        self.height = height
+
+        # pylint: disable=R1732
+        self.file = open(dst_path, "wb")
+        # pylint: enable=R1732
+
+    def write_one_frame(self, y, uv):
+        # y: 1xhxw uint8 numpy array
+        # uv: 2x(h/2)x(w/2) uint8 numpy array
+        self.file.write(y.tobytes())
+        self.file.write(uv.tobytes())
+
+    def close(self):
+        self.file.close()
diff --git a/DCVC-RT/test_video.py b/DCVC-RT/test_video.py
new file mode 100644
index 0000000..8658569
--- /dev/null
+++ b/DCVC-RT/test_video.py
@@ -0,0 +1,541 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+ +import argparse +import concurrent.futures +import io +import json +import multiprocessing +import os +import time + +import torch +import numpy as np +from tqdm import tqdm + +from src.layers.cuda_inference import replicate_pad +from src.models.video_model import DMC +from src.models.image_model import DMCI +from src.utils.common import str2bool, create_folder, generate_log_json, get_state_dict, \ + dump_json, set_torch_env +from src.utils.stream_helper import SPSHelper, NalType, write_sps, read_header, \ + read_sps_remaining, read_ip_remaining, write_ip +from src.utils.video_reader import PNGReader, YUV420Reader +from src.utils.video_writer import PNGWriter, YUV420Writer +from src.utils.metrics import calc_psnr, calc_msssim, calc_msssim_rgb +from src.utils.transforms import rgb2ycbcr, ycbcr2rgb, yuv_444_to_420, ycbcr420_to_444_np + + +def parse_args(): + parser = argparse.ArgumentParser(description="Example testing script") + + parser.add_argument('--force_zero_thres', type=float, default=None, required=False) + parser.add_argument('--model_path_i', type=str) + parser.add_argument('--model_path_p', type=str) + parser.add_argument('--rate_num', type=int, default=4) + parser.add_argument('--qp_i', type=int, nargs="+") + parser.add_argument('--qp_p', type=int, nargs="+") + parser.add_argument("--force_intra", type=str2bool, default=False) + parser.add_argument("--force_frame_num", type=int, default=-1) + parser.add_argument("--force_intra_period", type=int, default=-1) + parser.add_argument('--reset_interval', type=int, default=32, required=False) + parser.add_argument('--test_config', type=str, required=True) + parser.add_argument('--force_root_path', type=str, default=None, required=False) + parser.add_argument("--worker", "-w", type=int, default=1, help="worker number") + parser.add_argument("--cuda", type=str2bool, default=False) + parser.add_argument('--cuda_idx', type=int, nargs="+", help='GPU indexes to use') + parser.add_argument('--calc_ssim', 
type=str2bool, default=False, required=False) + parser.add_argument('--write_stream', type=str2bool, default=False) + parser.add_argument('--check_existing', type=str2bool, default=False) + parser.add_argument('--stream_path', type=str, default="out_bin") + parser.add_argument('--save_decoded_frame', type=str2bool, default=False) + parser.add_argument('--output_path', type=str, required=True) + parser.add_argument('--verbose_json', type=str2bool, default=False) + parser.add_argument('--verbose', type=int, default=0) + + args = parser.parse_args() + return args + + +def np_image_to_tensor(img, device): + image = torch.from_numpy(img).to(device=device).to(dtype=torch.float32) / 255.0 + image = image.unsqueeze(0) + return image + + +def get_src_reader(args): + if args['src_type'] == 'png': + src_reader = PNGReader(args['src_path'], args['src_width'], args['src_height']) + elif args['src_type'] == 'yuv420': + src_reader = YUV420Reader(args['src_path'], args['src_width'], args['src_height']) + return src_reader + + +def get_src_frame(args, src_reader, device): + if args['src_type'] == 'yuv420': + y, uv = src_reader.read_one_frame() + yuv = ycbcr420_to_444_np(y, uv) + x = np_image_to_tensor(yuv, device) + y = y[0, :, :] + u = uv[0, :, :] + v = uv[1, :, :] + rgb = None + else: + assert args['src_type'] == 'png' + rgb = src_reader.read_one_frame() + x = np_image_to_tensor(rgb, device) + x = rgb2ycbcr(x) + y, u, v = None, None, None + + x = x.to(torch.float16) + return x, y, u, v, rgb + + +def get_distortion(args, x_hat, y, u, v, rgb): + if args['src_type'] == 'yuv420': + y_rec, uv_rec = yuv_444_to_420(x_hat) + y_rec = torch.clamp(y_rec * 255, 0, 255).squeeze(0).cpu().numpy() + uv_rec = torch.clamp(uv_rec * 255, 0, 255).squeeze(0).cpu().numpy() + y_rec = y_rec[0, :, :] + u_rec = uv_rec[0, :, :] + v_rec = uv_rec[1, :, :] + psnr_y = calc_psnr(y, y_rec) + psnr_u = calc_psnr(u, u_rec) + psnr_v = calc_psnr(v, v_rec) + psnr = (6 * psnr_y + psnr_u + psnr_v) / 8 + if 
args['calc_ssim']: + ssim_y = calc_msssim(y, y_rec) + ssim_u = calc_msssim(u, u_rec) + ssim_v = calc_msssim(v, v_rec) + else: + ssim_y, ssim_u, ssim_v = 0., 0., 0. + ssim = (6 * ssim_y + ssim_u + ssim_v) / 8 + + curr_psnr = [psnr, psnr_y, psnr_u, psnr_v] + curr_ssim = [ssim, ssim_y, ssim_u, ssim_v] + else: + assert args['src_type'] == 'png' + rgb_rec = ycbcr2rgb(x_hat) + rgb_rec = torch.clamp(rgb_rec * 255, 0, 255).squeeze(0).cpu().numpy() + psnr = calc_psnr(rgb, rgb_rec) + if args['calc_ssim']: + msssim = calc_msssim_rgb(rgb, rgb_rec) + else: + msssim = 0. + curr_psnr = [psnr] + curr_ssim = [msssim] + return curr_psnr, curr_ssim + + +def run_one_point_with_stream(p_frame_net, i_frame_net, args): + if args['check_existing'] and os.path.exists(args['curr_json_path']) and \ + os.path.exists(args['curr_bin_path']): + with open(args['curr_json_path']) as f: + log_result = json.load(f) + if log_result['i_frame_num'] + log_result['p_frame_num'] == args['frame_num']: + return log_result + print(f"incorrect log for {args['curr_json_path']}, try to rerun.") + + frame_num = args['frame_num'] + save_decoded_frame = args['save_decoded_frame'] + verbose = args['verbose'] + reset_interval = args['reset_interval'] + intra_period = args['intra_period'] + verbose_json = args['verbose_json'] + device = next(i_frame_net.parameters()).device + + src_reader = get_src_reader(args) + pic_height = args['src_height'] + pic_width = args['src_width'] + padding_r, padding_b = DMCI.get_padding_size(pic_height, pic_width, 16) + + use_two_entropy_coders = pic_height * pic_width > 1280 * 720 + i_frame_net.set_use_two_entropy_coders(use_two_entropy_coders) + p_frame_net.set_use_two_entropy_coders(use_two_entropy_coders) + + frame_types = [] + psnrs = [] + msssims = [] + bits = [] + + start_time = time.time() + encoding_time = [] + decoding_time = [] + index_map = [0, 1, 0, 2, 0, 2, 0, 2] + + output_buff = io.BytesIO() + sps_helper = SPSHelper() + + p_frame_net.set_curr_poc(0) + with 
torch.no_grad(): + last_qp = 0 + for frame_idx in range(frame_num): + x, y, u, v, rgb = get_src_frame(args, src_reader, device) + + torch.cuda.synchronize(device=device) + frame_start_time = time.time() + + # pad if necessary + x_padded = replicate_pad(x, padding_b, padding_r) + + is_i_frame = False + if frame_idx == 0 or (intra_period > 0 and frame_idx % intra_period == 0): + is_i_frame = True + curr_qp = args['qp_i'] + sps = { + 'sps_id': -1, + 'height': pic_height, + 'width': pic_width, + 'ec_part': 1 if use_two_entropy_coders else 0, + 'use_ada_i': 0, + } + encoded = i_frame_net.compress(x_padded, args['qp_i']) + p_frame_net.clear_dpb() + p_frame_net.add_ref_frame(None, encoded['x_hat']) + frame_types.append(0) + else: + fa_idx = index_map[frame_idx % 8] + if reset_interval > 0 and frame_idx % reset_interval == 1: + use_ada_i = 1 + p_frame_net.prepare_feature_adaptor_i(last_qp) + else: + use_ada_i = 0 + curr_qp = p_frame_net.shift_qp(args['qp_p'], fa_idx) + sps = { + 'sps_id': -1, + 'height': pic_height, + 'width': pic_width, + 'ec_part': 1 if use_two_entropy_coders else 0, + 'use_ada_i': use_ada_i, + } + + encoded = p_frame_net.compress(x_padded, curr_qp) + last_qp = curr_qp + frame_types.append(1) + + sps_id, sps_new = sps_helper.get_sps_id(sps) + sps['sps_id'] = sps_id + sps_bytes = 0 + if sps_new: + sps_bytes = write_sps(output_buff, sps) + if verbose >= 2: + print("new sps", sps) + stream_bytes = write_ip(output_buff, is_i_frame, sps_id, curr_qp, encoded['bit_stream']) + bits.append(stream_bytes * 8 + sps_bytes * 8) + + torch.cuda.synchronize(device=device) + frame_end_time = time.time() + + frame_time = frame_end_time - frame_start_time + encoding_time.append(frame_time) + + if verbose >= 2: + print(f"frame {frame_idx} encoded, {frame_time * 1000:.3f} ms, " + f"bits: {bits[-1]}") + + src_reader.close() + with open(args['curr_bin_path'], "wb") as output_file: + bytes_buffer = output_buff.getbuffer() + output_file.write(bytes_buffer) + total_bytes = 
bytes_buffer.nbytes + bytes_buffer.release() + total_kbps = int(total_bytes * 8 / (frame_num / 30) / 1000) # assume 30 fps + output_buff.close() + sps_helper = SPSHelper() + input_file = open(args['curr_bin_path'], "rb") + with open(args['curr_bin_path'], "rb") as input_file: + input_buff = io.BytesIO(input_file.read()) + decoded_frame_number = 0 + src_reader = get_src_reader(args) + + if save_decoded_frame: + if args['src_type'] == 'png': + recon_writer = PNGWriter(args['bin_folder'], args['src_width'], args['src_height']) + elif args['src_type'] == 'yuv420': + output_yuv_path = args['curr_rec_path'].replace('.yuv', f'_{total_kbps}kbps.yuv') + recon_writer = YUV420Writer(output_yuv_path, args['src_width'], args['src_height']) + + p_frame_net.set_curr_poc(0) + with torch.no_grad(): + while decoded_frame_number < frame_num: + x, y, u, v, rgb = get_src_frame(args, src_reader, device) + torch.cuda.synchronize(device=device) + frame_start_time = time.time() + + header = read_header(input_buff) + while header['nal_type'] == NalType.NAL_SPS: + sps = read_sps_remaining(input_buff, header['sps_id']) + sps_helper.add_sps_by_id(sps) + if verbose >= 2: + print("new sps", sps) + header = read_header(input_buff) + continue + sps_id = header['sps_id'] + + sps = sps_helper.get_sps_by_id(sps_id) + qp, bit_stream = read_ip_remaining(input_buff) + + if header['nal_type'] == NalType.NAL_I: + decoded = i_frame_net.decompress(bit_stream, sps, qp) + p_frame_net.clear_dpb() + p_frame_net.add_ref_frame(None, decoded['x_hat']) + elif header['nal_type'] == NalType.NAL_P: + if sps['use_ada_i']: + p_frame_net.reset_ref_feature() + decoded = p_frame_net.decompress(bit_stream, sps, qp) + + recon_frame = decoded['x_hat'] + x_hat = recon_frame[:, :, :pic_height, :pic_width] + + torch.cuda.synchronize(device=device) + frame_end_time = time.time() + + frame_time = frame_end_time - frame_start_time + decoding_time.append(frame_time) + + curr_psnr, curr_ssim = get_distortion(args, x_hat, y, u, v, 
rgb) + psnrs.append(curr_psnr) + msssims.append(curr_ssim) + + if verbose >= 2: + stream_length = 0 if bit_stream is None else len(bit_stream) * 8 + print(f"frame {decoded_frame_number} decoded, {frame_time * 1000:.3f} ms, " + f"bits: {stream_length}, PSNR: {curr_psnr[0]:.4f} ") + + if save_decoded_frame: + if args['src_type'] == 'yuv420': + y_rec, uv_rec = yuv_444_to_420(x_hat) + y_rec = torch.clamp(y_rec * 255, 0, 255).round().to(dtype=torch.uint8) + y_rec = y_rec.squeeze(0).cpu().numpy() + uv_rec = torch.clamp(uv_rec * 255, 0, 255).to(dtype=torch.uint8) + uv_rec = uv_rec.squeeze(0).cpu().numpy() + recon_writer.write_one_frame(y_rec, uv_rec) + else: + assert args['src_type'] == 'png' + rgb_rec = ycbcr2rgb(x_hat) + rgb_rec = torch.clamp(rgb_rec * 255, 0, 255).round().to(dtype=torch.uint8) + rgb_rec = rgb_rec.squeeze(0).cpu().numpy() + recon_writer.write_one_frame(rgb_rec) + decoded_frame_number += 1 + input_buff.close() + src_reader.close() + + if save_decoded_frame: + recon_writer.close() + + test_time = time.time() - start_time + test_time_frame_numuber = len(encoding_time) + time_bypass_frame_num = 10 # bypass the first 10 frames as warmup + if verbose >= 1 and test_time_frame_numuber > time_bypass_frame_num: + encoding_time = encoding_time[time_bypass_frame_num:] + decoding_time = decoding_time[time_bypass_frame_num:] + avg_encoding_time = sum(encoding_time)/len(encoding_time) + avg_decoding_time = sum(decoding_time)/len(decoding_time) + print(f"encoding/decoding {test_time_frame_numuber} frames, " + f"average encoding time {avg_encoding_time * 1000:.3f} ms, " + f"average decoding time {avg_decoding_time * 1000:.3f} ms.") + else: + avg_encoding_time = None + avg_decoding_time = None + + log_result = generate_log_json(frame_num, pic_height * pic_width, test_time, + frame_types, bits, psnrs, msssims, verbose=verbose_json, + avg_encoding_time=avg_encoding_time, + avg_decoding_time=avg_decoding_time,) + with open(args['curr_json_path'], 'w') as fp: + 
json.dump(log_result, fp, indent=2) + return log_result + + +i_frame_net = None # the model is initialized after each process is spawn, thus OK for multiprocess +p_frame_net = None + + +def worker(args): + global i_frame_net + global p_frame_net + + sub_dir_name = args['seq'] + bin_folder = os.path.join(args['stream_path'], args['ds_name']) + assert args['write_stream'], "" + create_folder(bin_folder, True) + + args['src_path'] = os.path.join(args['dataset_path'], sub_dir_name) + args['bin_folder'] = bin_folder + args['curr_bin_path'] = os.path.join(bin_folder, + f"{args['seq']}_q{args['qp_i']}.bin") + args['curr_rec_path'] = args['curr_bin_path'].replace('.bin', '.yuv') + args['curr_json_path'] = args['curr_bin_path'].replace('.bin', '.json') + + result = run_one_point_with_stream(p_frame_net, i_frame_net, args) + + result['ds_name'] = args['ds_name'] + result['seq'] = args['seq'] + result['rate_idx'] = args['rate_idx'] + result['qp_i'] = args['qp_i'] + result['qp_p'] = args['qp_p'] if 'qp_p' in args else args['qp_i'] + + return result + + +def init_func(args, gpu_num): + set_torch_env() + + process_name = multiprocessing.current_process().name + process_idx = int(process_name[process_name.rfind('-') + 1:]) + gpu_id = -1 + if gpu_num > 0: + gpu_id = process_idx % gpu_num + if gpu_id >= 0: + if args.cuda_idx is not None: + gpu_id = args.cuda_idx[gpu_id] + os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) + device = "cuda:0" + else: + device = "cpu" + + global i_frame_net + i_frame_net = DMCI() + i_state_dict = get_state_dict(args.model_path_i) + i_frame_net.load_state_dict(i_state_dict) + i_frame_net = i_frame_net.to(device) + i_frame_net.eval() + i_frame_net.update(args.force_zero_thres) + i_frame_net.half() + + global p_frame_net + p_frame_net = DMC() + if not args.force_intra: + p_state_dict = get_state_dict(args.model_path_p) + p_frame_net.load_state_dict(p_state_dict) + p_frame_net = p_frame_net.to(device) + p_frame_net.eval() + 
+        p_frame_net.update(args.force_zero_thres)
+        p_frame_net.half()
+
+
+def main():
+    begin_time = time.time()
+
+    args = parse_args()
+
+    if args.force_zero_thres is not None and args.force_zero_thres < 0:
+        args.force_zero_thres = None
+
+    if args.cuda_idx is not None:
+        cuda_device = ','.join([str(s) for s in args.cuda_idx])
+        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
+
+    worker_num = args.worker
+    assert worker_num >= 1
+
+    with open(args.test_config) as f:
+        config = json.load(f)
+
+    gpu_num = 0
+    if args.cuda:
+        gpu_num = torch.cuda.device_count()
+
+    multiprocessing.set_start_method("spawn")
+    process_pool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num,
+                                                                   initializer=init_func,
+                                                                   initargs=(args, gpu_num))
+    objs = []
+
+    count_frames = 0
+    count_sequences = 0
+
+    rate_num = args.rate_num
+    qp_i = []
+    if args.qp_i is not None:
+        assert len(args.qp_i) == rate_num
+        qp_i = args.qp_i
+    else:
+        assert 2 <= rate_num <= DMC.get_qp_num()
+        for i in np.linspace(0, DMC.get_qp_num() - 1, num=rate_num):
+            qp_i.append(int(i + 0.5))
+
+    if not args.force_intra:
+        if args.qp_p is not None:
+            assert len(args.qp_p) == rate_num
+            qp_p = args.qp_p
+        else:
+            qp_p = qp_i
+
+    print(f"testing {rate_num} rates, using qp: ", end='')
+    for q in qp_i:
+        print(f"{q}, ", end='')
+    print()
+
+    root_path = args.force_root_path if args.force_root_path is not None else config['root_path']
+    config = config['test_classes']
+    for ds_name in config:
+        if config[ds_name]['test'] == 0:
+            continue
+        for seq in config[ds_name]['sequences']:
+            count_sequences += 1
+            for rate_idx in range(rate_num):
+                cur_args = {}
+                cur_args['rate_idx'] = rate_idx
+                cur_args['qp_i'] = qp_i[rate_idx]
+                if not args.force_intra:
+                    cur_args['qp_p'] = qp_p[rate_idx]
+                cur_args['force_intra'] = args.force_intra
+                cur_args['reset_interval'] = args.reset_interval
+                cur_args['seq'] = seq
+                cur_args['src_type'] = config[ds_name]['src_type']
+                cur_args['src_height'] = config[ds_name]['sequences'][seq]['height']
+                cur_args['src_width'] = config[ds_name]['sequences'][seq]['width']
+                cur_args['intra_period'] = config[ds_name]['sequences'][seq]['intra_period']
+                if args.force_intra:
+                    cur_args['intra_period'] = 1
+                if args.force_intra_period > 0:
+                    cur_args['intra_period'] = args.force_intra_period
+                cur_args['frame_num'] = config[ds_name]['sequences'][seq]['frames']
+                if args.force_frame_num > 0:
+                    cur_args['frame_num'] = args.force_frame_num
+                cur_args['calc_ssim'] = args.calc_ssim
+                cur_args['dataset_path'] = os.path.join(root_path, config[ds_name]['base_path'])
+                cur_args['write_stream'] = args.write_stream
+                cur_args['check_existing'] = args.check_existing
+                cur_args['stream_path'] = args.stream_path
+                cur_args['save_decoded_frame'] = args.save_decoded_frame
+                cur_args['ds_name'] = ds_name
+                cur_args['verbose'] = args.verbose
+                cur_args['verbose_json'] = args.verbose_json
+
+                count_frames += cur_args['frame_num']
+
+                obj = process_pool_executor.submit(worker, cur_args)
+                objs.append(obj)
+
+    results = []
+    for obj in tqdm(objs):
+        result = obj.result()
+        results.append(result)
+
+    log_result = {}
+    for ds_name in config:
+        if config[ds_name]['test'] == 0:
+            continue
+        log_result[ds_name] = {}
+        for seq in config[ds_name]['sequences']:
+            log_result[ds_name][seq] = {}
+
+    for res in results:
+        log_result[res['ds_name']][res['seq']][f"{res['rate_idx']:03d}"] = res
+
+    out_json_dir = os.path.dirname(args.output_path)
+    if len(out_json_dir) > 0:
+        create_folder(out_json_dir, True)
+    with open(args.output_path, 'w') as fp:
+        dump_json(log_result, fp, float_digits=6, indent=2)
+
+    total_minutes = (time.time() - begin_time) / 60
+    print('Test finished')
+    print(f'Tested {count_frames} frames from {count_sequences} sequences')
+    print(f'Total elapsed time: {total_minutes:.1f} min')
+
+
+if __name__ == "__main__":
+    main()
diff --git a/README.md b/README.md
index eb992ae..72a927c 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,8 @@
 Official Pytorch implementation for Neural Video and Image Compression including
   * DCVC-FM: [Neural Video Compression with **F**eature **M**odulation](https://arxiv.org/abs/2402.17414), CVPR 2024, in [this folder](./DCVC-FM/).
     - The first end-to-end neural video codec to exceed ECM using the highest-compression-ratio low-delay configuration with only one intra frame, in terms of PSNR for both YUV420 and RGB content, in a single model.
     - The first end-to-end neural video codec that supports a large quality and bitrate range in a single model.
+  * DCVC-RT: [Towards Practical **R**eal-**T**ime Neural Video Compression](https://arxiv.org/abs/2502.20762), CVPR 2025, in [this folder](./DCVC-RT/).
+    - The first end-to-end neural video codec to achieve 100+ fps encoding and decoding for 1080p video on an NVIDIA A100 GPU, with an overall compression ratio comparable to DCVC-FM.
 * Neural Image Codec
   * [EVC: Towards Real-Time Neural Image Compression with Mask Decay](https://openreview.net/forum?id=XUxad2Gj40n), ICLR 2023, in [this folder](./EVC/).
@@ -77,6 +79,14 @@ If you find this work useful for your research, please cite:
     year={2024}
 }
 
+@inproceedings{jia2025towards,
+  title={Towards Practical Real-Time Neural Video Compression},
+  author={Jia, Zhaoyang and Li, Bin and Li, Jiahao and Xie, Wenxuan and Qi, Linfeng and Li, Houqiang and Lu, Yan},
+  booktitle={{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
+             {CVPR} 2025, Nashville, TN, USA, June 11-15, 2025},
+  year={2025}
+}
+
 @inproceedings{wang2023EVC,
   title={EVC: Towards Real-Time Neural Image Compression with Mask Decay},
   author={Wang, Guo-Hua and Li, Jiahao and Li, Bin and Lu, Yan},