From 838b3cdb358448a5ccbaf9c2703246bcdd4dcf48 Mon Sep 17 00:00:00 2001
From: David Hewitt
Date: Wed, 18 Feb 2026 10:11:49 +0000
Subject: [PATCH] Support deserializing Python dataclass into structs / mappings

Dataclass instances are presented to serde as maps keyed by field name.
Field names are taken from `dataclasses.fields()` and values are read with
`getattr`, so `__slots__`-based dataclasses are supported and `ClassVar` /
`InitVar` pseudo-fields (which are not stored on the instance) are skipped.
Dataclass types themselves are not treated as dataclass instances.
---
 CHANGELOG.md |   1 +
 src/de.rs    | 289 ++++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 281 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 744e6e8..7c25d03 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 - Bump MSRV to 1.83.
 - Update `pyo3` to 0.28.
+- Support deserializing `dataclass` instances to struct-like Rust types.
 - Add `arbitrary_precision` feature
 
 ## 0.27.0 - 2025-11-07
diff --git a/src/de.rs b/src/de.rs
index 03107ae..260d625 100644
--- a/src/de.rs
+++ b/src/de.rs
@@ -1,4 +1,4 @@
-use pyo3::{types::*, Bound};
+use pyo3::{intern, types::*, Bound};
 use serde::de::{self, IntoDeserializer};
 use serde::Deserialize;
 
@@ -7,7 +7,16 @@ use crate::error::{ErrorImpl, PythonizeError, Result};
 #[cfg(feature = "arbitrary_precision")]
 const TOKEN: &str = "$serde_json::private::Number";
 
-/// Attempt to convert a Python object to an instance of `T`
+/// Attempt to convert a Python object to an instance of `T`.
+///
+/// Generally this only supports Python types that match `serde`'s object model well:
+/// - integers (including arbitrary precision integers if the `arbitrary_precision` feature is enabled)
+/// - floats
+/// - strings
+/// - bytes
+/// - `collections.abc.Sequence` instances (as serde sequences)
+/// - `collections.abc.Mapping` instances (as serde maps)
+/// - dataclass instances, including `__slots__`-based ones (as serde maps keyed by field name)
 pub fn depythonize<'a, 'py, T>(obj: &'a Bound<'py, PyAny>) -> Result<T>
 where
     T: Deserialize<'a>,
@@ -55,6 +64,13 @@ impl<'a, 'py> Depythonizer<'a, 'py> {
         PyMappingAccess::new(self.input.cast()?)
     }
 
+    /// Returns a map accessor over the input's fields if the input is a dataclass instance.
+    fn dataclass_access(&self) -> Result<Option<PyDataclassAccess<'py>>> {
+        DataclassCandidate::try_new(self.input)
+            .map(PyDataclassAccess::new)
+            .transpose()
+    }
+
     fn deserialize_any_int<'de, V>(&self, int: &Bound<'_, PyInt>, visitor: V) -> Result<V::Value>
     where
         V: de::Visitor<'de>,
@@ -147,6 +163,8 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
             self.deserialize_tuple(obj.len()?, visitor)
         } else if obj.cast::<PyMapping>().is_ok() {
             self.deserialize_map(visitor)
+        } else if let Some(dc) = DataclassCandidate::try_new(obj) {
+            visitor.visit_map(PyDataclassAccess::new(dc)?)
         } else {
             Err(obj.get_type().qualname().map_or_else(
                 |_| PythonizeError::unsupported_type("unknown"),
@@ -293,7 +311,11 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
     where
         V: de::Visitor<'de>,
     {
-        visitor.visit_map(self.dict_access()?)
+        if let Some(dc_access) = self.dataclass_access()? {
+            visitor.visit_map(dc_access)
+        } else {
+            visitor.visit_map(self.dict_access()?)
+        }
     }
 
     fn deserialize_struct<V>(
@@ -470,6 +492,88 @@ impl<'de> de::MapAccess<'de> for PyMappingAccess<'_> {
     }
 }
 
+/// Intermediate structure used to denote that `obj` is a dataclass instance.
+struct DataclassCandidate<'a, 'py> {
+    obj: &'a Bound<'py, PyAny>,
+}
+
+impl<'a, 'py> DataclassCandidate<'a, 'py> {
+    /// Returns `Some` if `obj` is an instance of a dataclass.
+    ///
+    /// The `__dataclass_fields__` marker is also visible on dataclass types themselves; those
+    /// are rejected because only instances carry field values.
+    fn try_new(obj: &'a Bound<'py, PyAny>) -> Option<Self> {
+        if obj.is_instance_of::<PyType>() {
+            return None;
+        }
+        obj.getattr_opt(intern!(obj.py(), "__dataclass_fields__"))
+            .ok()
+            .flatten()
+            .map(|_| Self { obj })
+    }
+}
+
+/// `MapAccess` over the fields of a dataclass instance, keyed by field name.
+struct PyDataclassAccess<'py> {
+    obj: Bound<'py, PyAny>,
+    names: std::vec::IntoIter<Bound<'py, PyString>>,
+    /// Field name yielded by `next_key_seed` whose value has not been read yet.
+    pending: Option<Bound<'py, PyString>>,
+}
+
+impl<'py> PyDataclassAccess<'py> {
+    /// Collects the field names of the dataclass instance held by `dc`.
+    fn new(dc: DataclassCandidate<'_, 'py>) -> Result<Self> {
+        let py = dc.obj.py();
+        // `dataclasses.fields()` skips the `ClassVar` / `InitVar` pseudo-fields, which are
+        // listed in `__dataclass_fields__` but are not stored on the instance.
+        let names = py
+            .import(intern!(py, "dataclasses"))?
+            .getattr(intern!(py, "fields"))?
+            .call1((dc.obj,))?
+            .try_iter()?
+            .map(|field| -> Result<Bound<'py, PyString>> {
+                Ok(field?
+                    .getattr(intern!(py, "name"))?
+                    .cast_into::<PyString>()?)
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self {
+            obj: dc.obj.clone(),
+            names: names.into_iter(),
+            pending: None,
+        })
+    }
+}
+
+impl<'de> de::MapAccess<'de> for PyDataclassAccess<'_> {
+    type Error = PythonizeError;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
+    where
+        K: de::DeserializeSeed<'de>,
+    {
+        let Some(name) = self.names.next() else {
+            return Ok(None);
+        };
+        let key = seed.deserialize(&mut Depythonizer::from_object(name.as_any()))?;
+        self.pending = Some(name);
+        Ok(Some(key))
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
+    where
+        V: de::DeserializeSeed<'de>,
+    {
+        let name = self.pending.take().ok_or_else(|| {
+            <PythonizeError as de::Error>::custom("dataclass value requested before its key")
+        })?;
+        // Attribute lookup works for both `__dict__`- and `__slots__`-backed dataclasses.
+        let value = self.obj.getattr(name)?;
+        seed.deserialize(&mut Depythonizer::from_object(&value))
+    }
+}
+
 struct PyEnumAccess<'a, 'py> {
     de: Depythonizer<'a, 'py>,
     variant: Bound<'py, PyString>,
@@ -558,7 +662,7 @@ impl<'de> de::MapAccess<'de> for NumberDeserializer {
 
 #[cfg(test)]
 mod test {
-    use std::ffi::CStr;
+    use std::{collections::HashMap, ffi::CStr};
 
     use super::*;
     use crate::error::ErrorImpl;
@@ -572,14 +676,21 @@ mod test {
     {
         Python::attach(|py| {
             let obj = py.eval(code, None, None).unwrap();
-            let actual: T = depythonize(&obj).unwrap();
-            assert_eq!(&actual, expected);
-
-            let actual_json: JsonValue = depythonize(&obj).unwrap();
-            assert_eq!(&actual_json, expected_json);
+            test_de_with_obj(&obj, expected, expected_json);
         });
     }
 
+    fn test_de_with_obj<T>(obj: &Bound<'_, PyAny>, expected: &T, expected_json: &JsonValue)
+    where
+        T: de::DeserializeOwned + PartialEq + std::fmt::Debug,
+    {
+        let actual: T = depythonize(obj).unwrap();
+        assert_eq!(&actual, expected);
+
+        let actual_json: JsonValue = depythonize(obj).unwrap();
+        assert_eq!(&actual_json, expected_json);
+    }
+
     #[test]
     fn test_empty_struct() {
         #[derive(Debug, Deserialize, PartialEq)]
@@ -930,4 +1041,164 @@ mod test {
             ));
         });
     }
+
+    #[test]
+    fn test_dataclass() {
+        let code = c"\
+from dataclasses import dataclass
+
+@dataclass
+class Point:
+    x: int
+    y: int
+
+point = Point(1, 2)";
+
+        #[derive(Debug, Deserialize, PartialEq)]
+        struct Point {
+            x: i32,
+            y: i32,
+        }
+
+        let expected = Point { x: 1, y: 2 };
+        let expected_json = json!({"x": 1, "y": 2});
+
+        Python::attach(|py| {
+            let locals = PyDict::new(py);
+            py.run(code, None, Some(&locals)).unwrap();
+            let obj = locals.get_item("point").unwrap().unwrap();
+            test_de_with_obj(&obj, &expected, &expected_json);
+
+            let map: HashMap<String, i32> = depythonize(&obj).unwrap();
+            assert_eq!(map.len(), 2);
+            assert_eq!(*map.get("x").unwrap(), 1);
+            assert_eq!(*map.get("y").unwrap(), 2);
+        });
+    }
+
+    #[test]
+    fn test_dataclass_missing_field() {
+        let code = c"\
+from dataclasses import dataclass
+
+@dataclass
+class Point:
+    x: int
+    y: int
+
+point = Point(1, 2)";
+
+        #[derive(Debug, Deserialize, PartialEq)]
+        struct Point {
+            x: i32,
+            y: i32,
+            z: i32,
+        }
+
+        Python::attach(|py| {
+            let locals = PyDict::new(py);
+            py.run(code, None, Some(&locals)).unwrap();
+            let obj = locals.get_item("point").unwrap().unwrap();
+            let err = depythonize::<Point>(&obj).unwrap_err();
+            assert!(matches!(
+                *err.inner,
+                ErrorImpl::Message(msg) if msg == "missing field `z`"
+            ));
+        });
+    }
+
+    #[test]
+    fn test_dataclass_extra_field() {
+        let code = c"\
+from dataclasses import dataclass
+
+@dataclass
+class Point:
+    x: int
+    y: int
+    z: int
+
+point = Point(1, 2, 3)";
+
+        #[derive(Debug, Deserialize, PartialEq)]
+        #[serde(deny_unknown_fields)]
+        struct Point {
+            x: i32,
+            y: i32,
+        }
+
+        Python::attach(|py| {
+            let locals = PyDict::new(py);
+            py.run(code, None, Some(&locals)).unwrap();
+            let obj = locals.get_item("point").unwrap().unwrap();
+            let err = depythonize::<Point>(&obj).unwrap_err();
+            assert!(matches!(
+                *err.inner,
+                ErrorImpl::Message(msg) if msg == "unknown field `z`, expected `x` or `y`"
+            ));
+        });
+    }
+
+    #[test]
+    fn test_dataclass_slots() {
+        let code = c"\
+from dataclasses import dataclass
+
+@dataclass
+class Point:
+    __slots__ = ('x', 'y')
+    x: int
+    y: int
+
+point = Point(1, 2)";
+
+        #[derive(Debug, Deserialize, PartialEq)]
+        struct Point {
+            x: i32,
+            y: i32,
+        }
+
+        let expected = Point { x: 1, y: 2 };
+        let expected_json = json!({"x": 1, "y": 2});
+
+        Python::attach(|py| {
+            let locals = PyDict::new(py);
+            py.run(code, None, Some(&locals)).unwrap();
+            let obj = locals.get_item("point").unwrap().unwrap();
+            test_de_with_obj(&obj, &expected, &expected_json);
+        });
+    }
+
+    #[test]
+    fn test_dataclass_class_var() {
+        let code = c"\
+from dataclasses import dataclass
+from typing import ClassVar
+
+@dataclass
+class Point:
+    dimensions: ClassVar[int] = 2
+    x: int
+    y: int
+
+point = Point(1, 2)";
+
+        #[derive(Debug, Deserialize, PartialEq)]
+        #[serde(deny_unknown_fields)]
+        struct Point {
+            x: i32,
+            y: i32,
+        }
+
+        let expected = Point { x: 1, y: 2 };
+        let expected_json = json!({"x": 1, "y": 2});
+
+        Python::attach(|py| {
+            // Run with a single namespace so the class body can resolve `ClassVar`.
+            let globals = PyDict::new(py);
+            py.run(code, Some(&globals), None).unwrap();
+            let obj = globals.get_item("point").unwrap().unwrap();
+            test_de_with_obj(&obj, &expected, &expected_json);
+        });
+    }
 }