From 9ae4f9c2ecec15151dd4d4934c8cb45bd54b2e27 Mon Sep 17 00:00:00 2001 From: Yajie Yan Date: Mon, 9 Mar 2026 15:45:42 -0700 Subject: [PATCH] Add DataLayout support for pyvrs2 VRSWriter Summary: * **rename of zstd**: needed because the OSS version only exposes ZSTD_MEDIUM via pybind. The internal version still works because it exposes both ZSTD_MEDIUM and Zmedium. Differential Revision: D95732614 --- csrc/writer/VRSWriter.cpp | 11 + .../datalayouts/AriaGen2ImageDataLayout.cpp | 52 ++++ .../datalayouts/AriaGen2ImageDataLayout.h | 135 ++++++++++ pyvrs/__init__.py | 6 + pyvrs/writer.py | 4 +- test/pyvrs_writer_test.py | 243 ++++++++++++++++++ 6 files changed, 449 insertions(+), 2 deletions(-) create mode 100644 csrc/writer/datalayouts/AriaGen2ImageDataLayout.cpp create mode 100644 csrc/writer/datalayouts/AriaGen2ImageDataLayout.h create mode 100644 test/pyvrs_writer_test.py diff --git a/csrc/writer/VRSWriter.cpp b/csrc/writer/VRSWriter.cpp index a6d745b..27008d7 100644 --- a/csrc/writer/VRSWriter.cpp +++ b/csrc/writer/VRSWriter.cpp @@ -40,6 +40,7 @@ #include "StreamFactory.h" // Open source DataLayout definitions +#include "datalayouts/AriaGen2ImageDataLayout.h" #include "datalayouts/SampleDataLayout.h" namespace py = pybind11; @@ -74,6 +75,16 @@ void VRSWriter::init() { "sample_with_image", createSampleStreamWithImage); StreamFactory::getInstance().registerStreamCreationFunction( "sample_with_multiple_data_layout", createSampleStreamWithMultipleDataLayout); + // Aria Gen2 camera streams with correct RecordableTypeId + H.265 content block + StreamFactory::getInstance().registerFlavoredStreamCreationFunction( + "aria_gen2_rgb_camera", [](const string& flavor) { + return createAriaGen2ImageStream( + flavor, RecordableTypeId::RgbCameraRecordableClass, "H.265"); + }); + StreamFactory::getInstance().registerFlavoredStreamCreationFunction( + "aria_gen2_slam_camera", [](const string& flavor) { + return createAriaGen2ImageStream(flavor, RecordableTypeId::SlamCameraData, "H.265"); + });
/// Register open source stream writers (end)
 
 #if IS_VRS_FB_INTERNAL()
diff --git a/csrc/writer/datalayouts/AriaGen2ImageDataLayout.cpp b/csrc/writer/datalayouts/AriaGen2ImageDataLayout.cpp
new file mode 100644
index 0000000..1e1ebfe
--- /dev/null
+++ b/csrc/writer/datalayouts/AriaGen2ImageDataLayout.cpp
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "AriaGen2ImageDataLayout.h"
+
+#include
+
+#include "../PyRecordable.h"
+
+using namespace vrs;
+
+namespace pyvrs {
+// Creates the stream from one CONFIGURATION and one DATA record format, both at kVersion.
+std::unique_ptr createAriaGen2ImageStream(
+    const std::string& flavor,
+    RecordableTypeId typeId,
+    const std::string& codec) {
+  constexpr uint32_t kVersion = 2; // NOTE(review): duplicates the layouts' kVersion (also 2)
+
+  auto configurationRecordFormat = std::make_unique(
+      Record::Type::CONFIGURATION,
+      kVersion,
+      std::make_unique(/*allocateVideoFields=*/true));
+  // Only the exact string "H.265" selects the video block; any other codec falls back to RAW.
+  auto dataContentBlocks = codec == "H.265"
+      ? std::vector{ContentBlock("H.265", ImageContentBlockSpec::kQualityUndefined)}
+      : std::vector{ContentBlock(ImageFormat::RAW)};
+
+  auto dataRecordFormat = std::make_unique(
+      Record::Type::DATA,
+      kVersion,
+      std::make_unique(/*allocateVideoFields=*/true),
+      dataContentBlocks);
+
+  return std::make_unique(
+      typeId, flavor, std::move(configurationRecordFormat), std::move(dataRecordFormat));
+}
+
+} // namespace pyvrs
diff --git a/csrc/writer/datalayouts/AriaGen2ImageDataLayout.h b/csrc/writer/datalayouts/AriaGen2ImageDataLayout.h
new file mode 100644
index 0000000..a2cfe7e
--- /dev/null
+++ b/csrc/writer/datalayouts/AriaGen2ImageDataLayout.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Aria Gen2 camera DataLayouts for OSS pyvrs writer.
+// Vendored from arvr/libraries/visiontypes/vrs/data_layouts/ImageDataLayout.h
+// with namespace changed from visiontypes::detail to pyvrs.
+
+#pragma once
+
+#include
+
+#include
+#include
+#include
+
+namespace pyvrs {
+
+using vrs::OptionalDataPieces;
+using vrs::datalayout_conventions::ImageSpecType;
+
+// Additional fields to enable in ImageSensorConfigurationLayout when data was
+// encoded as video before recording.
+struct VideoConfigurationFields {
+  vrs::DataPieceString videoCodecName{vrs::datalayout_conventions::kImageCodecName};
+};
+// Sensor configuration metadata; allocateVideoFields is forwarded to the optional video fields.
+struct ImageSensorConfigurationLayout : public vrs::AutoDataLayout {
+  static constexpr uint32_t kVersion = 2;
+
+  explicit ImageSensorConfigurationLayout(bool allocateVideoFields = false)
+      : videoConfigurationFields(allocateVideoFields) {}
+
+  vrs::DataPieceString deviceType{"device_type"};
+  vrs::DataPieceString deviceVersion{"device_version"};
+  vrs::DataPieceString deviceSerial{"device_serial"};
+
+  vrs::DataPieceValue cameraId{"camera_id"};
+  vrs::DataPieceValue streamType{"stream_type"};
+  vrs::DataPieceValue streamIndex{"stream_index"};
+
+  vrs::DataPieceString sensorModel{"sensor_model"};
+  vrs::DataPieceString sensorSerial{"sensor_serial"};
+
+  vrs::DataPieceValue nominalRateHz{"nominal_rate"};
+
+  vrs::DataPieceValue imageWidth{vrs::datalayout_conventions::kImageWidth};
+  vrs::DataPieceValue imageHeight{vrs::datalayout_conventions::kImageHeight};
+  vrs::DataPieceValue imageStride{vrs::datalayout_conventions::kImageStride};
+  vrs::DataPieceValue imageStride2{vrs::datalayout_conventions::kImageStride2};
+  vrs::DataPieceValue pixelFormat{vrs::datalayout_conventions::kImagePixelFormat};
+  vrs::DataPieceValue plane2OffsetRows{"image_plane_2_offset_rows"};
+  vrs::DataPieceValue plane3OffsetRows{"image_plane_3_offset_rows"};
+
+  vrs::DataPieceValue imageOrientation{"image_orientation"};
+  vrs::DataPieceValue shutterDirection{"shutter_direction"};
+
+  vrs::DataPieceValue exposureDurationMin{"exposure_duration.min"};
+  vrs::DataPieceValue exposureDurationMax{"exposure_duration.max"};
+
+  vrs::DataPieceValue gainMin{"gain.min"};
+  vrs::DataPieceValue gainMax{"gain.max"};
+
+  vrs::DataPieceValue gammaFactor{"gamma_factor"};
+
+  vrs::DataPieceString factoryCalibration{"factory_calibration"};
+  vrs::DataPieceString onlineCalibration{"online_calibration"};
+
+  vrs::DataPieceString description{"description"};
+
+  vrs::DataPieceString cameraMuxModeName{"camera_mux_mode_name"};
+
+  const OptionalDataPieces videoConfigurationFields;
+
+  vrs::AutoDataLayoutEnd end;
+};
+
+// Additional fields to enable in ImageDataLayout when data was encoded as video
+// before recording.
+struct VideoDataFields {
+  vrs::DataPieceValue keyFrameTimestamp{
+      vrs::datalayout_conventions::kImageKeyFrameTimeStamp};
+  vrs::DataPieceValue keyFrameIndex{
+      vrs::datalayout_conventions::kImageKeyFrameIndex};
+};
+// Per-image metadata; allocateVideoFields is forwarded to the optional videoDataFields member.
+struct ImageDataLayout : public vrs::AutoDataLayout {
+  static constexpr uint32_t kVersion = 2;
+
+  explicit ImageDataLayout(bool allocateVideoFields = false)
+      : videoDataFields(allocateVideoFields) {}
+
+  vrs::DataPieceValue groupId{"group_id"};
+  vrs::DataPieceValue groupMask{"group_mask"};
+  vrs::DataPieceValue streamIndexMask{"stream_index_mask"};
+  vrs::DataPieceValue frameNumber{"frame_number"};
+  vrs::DataPieceValue frameTag{"frame_tag"};
+  vrs::DataPieceValue exposureDuration{"exposure_duration_s"};
+  vrs::DataPieceValue gain{"gain"};
+  vrs::DataPieceValue readoutDurationSeconds{"readout_duration_s"};
+  vrs::DataPieceValue captureTimestampNs{"capture_timestamp_ns"};
+  vrs::DataPieceValue captureTimestampInProcessingClockDomainNs{
+      "capture_timestamp_in_processing_clock_domain_ns"};
+  vrs::DataPieceValue arrivalTimestampNs{"arrival_timestamp_ns"};
+  vrs::DataPieceValue processingStartTimestampNs{"processing_start_timestamp_ns"};
+  vrs::DataPieceValue temperature{"temperature_deg_c"};
+  vrs::DataPieceVector imageMetadata{"image_metadata"};
+
+  const OptionalDataPieces videoDataFields;
+  // NOTE(review): -1.0 is presumably the default value — confirm against vrs::DataPieceValue.
+  vrs::DataPieceValue focusDistanceMm{"focus_distance_mm", -1.0};
+
+  vrs::AutoDataLayoutEnd end;
+};
+
+class PyStream;
+// Defined in AriaGen2ImageDataLayout.cpp; when `codec` is omitted it defaults to "H.265".
+std::unique_ptr createAriaGen2ImageStream(
+    const std::string& flavor,
+    vrs::RecordableTypeId typeId,
+    const std::string& codec = "H.265");
+
+} // namespace pyvrs
diff --git a/pyvrs/__init__.py b/pyvrs/__init__.py
index fba80a7..d5a8442 100755
--- a/pyvrs/__init__.py
+++ b/pyvrs/__init__.py
@@
-35,12 +35,15 @@ recordable_type_id_name, RecordableId, RecordableTypeId, + RecordFormat, records_checksum, RecordType, + Stream, StreamNotFoundError, TimestampNotFoundError, verbatim_checksum, VRSRecord, + Writer, ) from .reader import AsyncVRSReader, SyncVRSReader @@ -69,10 +72,13 @@ "recordable_type_id_name", "RecordableId", "RecordableTypeId", + "RecordFormat", "records_checksum", "RecordType", + "Stream", "StreamNotFoundError", "TimestampNotFoundError", "verbatim_checksum", "VRSRecord", + "Writer", ] diff --git a/pyvrs/writer.py b/pyvrs/writer.py index 486d60d..74bf2be 100644 --- a/pyvrs/writer.py +++ b/pyvrs/writer.py @@ -57,7 +57,7 @@ def create_stream( self, name: str, flavor: str = "", - compression: CompressionPreset = CompressionPreset.Zmedium, + compression: CompressionPreset = CompressionPreset.ZSTD_MEDIUM, ) -> "VRSStream": if len(flavor) > 0: return VRSStream( @@ -121,7 +121,7 @@ def __init__( self, stream: Stream, writer: VRSWriter, - compression: CompressionPreset = CompressionPreset.Zmedium, + compression: CompressionPreset = CompressionPreset.ZSTD_MEDIUM, ) -> None: self.stream = stream self.stream.setCompression(compression) diff --git a/test/pyvrs_writer_test.py b/test/pyvrs_writer_test.py new file mode 100644 index 0000000..96921ec --- /dev/null +++ b/test/pyvrs_writer_test.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import pyvrs
+from pyvrs.writer import VRSWriter
+
+
+def _temp_vrs_path(test_case):
+    """Return a .vrs path inside a fresh temporary directory.
+
+    The directory is removed by a cleanup hook, even when the test fails.
+    """
+    tmp_dir = tempfile.TemporaryDirectory()
+    test_case.addCleanup(tmp_dir.cleanup)
+    return os.path.join(tmp_dir.name, "test.vrs")
+
+
+class TestWriterImports(unittest.TestCase):
+    """Verify that writer-related symbols are importable from pyvrs."""
+
+    def test_import_writer_symbols(self):
+        """The writer-facing classes are exported at package level."""
+        self.assertTrue(hasattr(pyvrs, "Writer"))
+        self.assertTrue(hasattr(pyvrs, "Stream"))
+        self.assertTrue(hasattr(pyvrs, "RecordFormat"))
+        self.assertTrue(hasattr(pyvrs, "CompressionPreset"))
+
+    def test_compression_preset_has_zstd_medium(self):
+        """The binding exposes ZSTD_MEDIUM, the default preset in pyvrs.writer."""
+        self.assertTrue(hasattr(pyvrs.CompressionPreset, "ZSTD_MEDIUM"))
+
+
+class TestSampleStreamWriter(unittest.TestCase):
+    """Test creating VRS files with sample streams."""
+
+    def test_create_sample_stream(self):
+        """Write one config and one data record, then read the file back."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+        stream = writer.create_stream("sample")
+        self.assertIsNotNone(stream)
+
+        stream_id = stream.get_stream_id()
+        self.assertIsInstance(stream_id, str)
+        self.assertGreater(len(stream_id), 0)
+
+        metadata = stream.get_config_record_metadata()
+        self.assertIsNotNone(metadata)
+
+        metadata.image_width = 640
+        metadata.image_height = 480
+        metadata.image_pixel_format = 1
+
+        stream.create_config_record(0.0, metadata)
+        writer.flush_records(0.0)
+
+        data_meta = stream.get_data_record_metadata()
+        data_meta.room_temperature = 22.5
+        data_meta.camera_id = 1
+        stream.create_data_record(1.0, data_meta)
+        writer.flush_records(1.0)
+
+        writer.close()
+
+        self.assertTrue(os.path.exists(filepath))
+        self.assertGreater(os.path.getsize(filepath), 0)
+
+        # Verify we can read it back
+        reader = pyvrs.SyncVRSReader(filepath)
+        stream_ids = reader.stream_ids
+        self.assertGreater(len(stream_ids), 0)
+
+    def test_create_sample_stream_with_flavor(self):
+        """Create a flavored stream and write its configuration record."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+        stream = writer.create_stream("flavored_sample", flavor="test_flavor")
+        self.assertIsNotNone(stream)
+
+        metadata = stream.get_config_record_metadata()
+        metadata.image_width = 320
+        metadata.image_height = 240
+        metadata.image_pixel_format = 1
+        stream.create_config_record(0.0, metadata)
+
+        writer.flush_records(0.0)
+        writer.close()
+
+        self.assertTrue(os.path.exists(filepath))
+
+
+class TestAriaGen2StreamWriter(unittest.TestCase):
+    """Test creating VRS files with Aria Gen2 camera streams."""
+
+    def test_create_aria_gen2_rgb_camera_stream(self):
+        """The RGB camera stream ID carries RecordableTypeId 214."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+        stream = writer.create_stream("aria_gen2_rgb_camera", flavor="camera-rgb")
+        self.assertIsNotNone(stream)
+
+        stream_id = stream.get_stream_id()
+        self.assertIsInstance(stream_id, str)
+        # RGB camera should have RecordableTypeId 214
+        self.assertIn("214", stream_id)
+
+        writer.close()
+
+    def test_create_aria_gen2_slam_camera_stream(self):
+        """The SLAM camera stream ID carries RecordableTypeId 1201."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+        stream = writer.create_stream(
+            "aria_gen2_slam_camera", flavor="camera-slam-left"
+        )
+        self.assertIsNotNone(stream)
+
+        stream_id = stream.get_stream_id()
+        self.assertIsInstance(stream_id, str)
+        # SLAM camera should have RecordableTypeId 1201
+        self.assertIn("1201", stream_id)
+
+        writer.close()
+
+    def test_aria_gen2_rgb_camera_write_with_image(self):
+        """Test writing config + data records with H.265 content block."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+        stream = writer.create_stream("aria_gen2_rgb_camera", flavor="camera-rgb")
+
+        # Write configuration record
+        config_meta = stream.get_config_record_metadata()
+        config_meta.image_width = 1408
+        config_meta.image_height = 1408
+        config_meta.image_pixel_format = 200  # YUV_420_NV21
+        config_meta.image_codec_name = "H.265"
+        stream.create_config_record(0.0, config_meta)
+        writer.flush_records(0.0)
+
+        # Write data record with fake encoded bytes
+        data_meta = stream.get_data_record_metadata()
+        data_meta.capture_timestamp_ns = 1000000000
+        fake_h265_bytes = np.zeros(1024, dtype=np.uint8)
+        stream.create_data_record(1.0, data_meta, fake_h265_bytes)
+        writer.flush_records(1.0)
+
+        writer.close()
+
+        self.assertTrue(os.path.exists(filepath))
+        self.assertGreater(os.path.getsize(filepath), 0)
+
+        # Verify we can open it and see the stream
+        reader = pyvrs.SyncVRSReader(filepath)
+        stream_ids = reader.stream_ids
+        self.assertGreater(len(stream_ids), 0)
+
+    def test_multiple_aria_gen2_streams(self):
+        """Test creating a VRS file with multiple Aria Gen2 camera streams."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+
+        rgb_stream = writer.create_stream("aria_gen2_rgb_camera", flavor="camera-rgb")
+        slam_left = writer.create_stream(
+            "aria_gen2_slam_camera", flavor="camera-slam-left"
+        )
+        slam_right = writer.create_stream(
+            "aria_gen2_slam_camera", flavor="camera-slam-right"
+        )
+
+        # All should have unique stream IDs
+        ids = {
+            rgb_stream.get_stream_id(),
+            slam_left.get_stream_id(),
+            slam_right.get_stream_id(),
+        }
+        self.assertEqual(len(ids), 3, "All streams should have unique IDs")
+
+        # Write config records for each
+        for stream in [rgb_stream, slam_left, slam_right]:
+            config_meta = stream.get_config_record_metadata()
+            config_meta.image_width = 640
+            config_meta.image_height = 480
+            stream.create_config_record(0.0, config_meta)
+
+        writer.flush_records(0.0)
+        writer.close()
+
+        self.assertTrue(os.path.exists(filepath))
+
+
+class TestWriterErrors(unittest.TestCase):
+    """Test error handling in the writer."""
+
+    def test_unsupported_stream_name_raises(self):
+        """An unregistered stream name is rejected with ValueError."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+        with self.assertRaises(ValueError):
+            writer.create_stream("nonexistent_stream_type")
+
+        writer.close()
+
+    def test_data_before_config_raises(self):
+        """Creating a data record before any config record raises."""
+        filepath = _temp_vrs_path(self)
+
+        writer = VRSWriter(filepath)
+        stream = writer.create_stream("sample")
+
+        data_meta = stream.get_data_record_metadata()
+        data_meta.room_temperature = 20.0
+        with self.assertRaises(Exception):  # NOTE(review): narrow to the concrete type
+            stream.create_data_record(1.0, data_meta)
+
+        writer.close()
+
+
+if __name__ == "__main__":
+    unittest.main()