diff --git a/src/certificate.rs b/src/certificate.rs index 77f2be4..560a8f4 100644 --- a/src/certificate.rs +++ b/src/certificate.rs @@ -80,7 +80,7 @@ pub fn generate_self_signed_certificate() -> Result, } impl std::fmt::Debug for EcdsaSigningKey { @@ -55,6 +56,16 @@ impl SigningKey for EcdsaSigningKey { panic!("Unsupported signing algorithm") } } + + fn clone_box(&self) -> Box { + let key_pair = EcdsaKeyPair::from_pkcs8(self.signing_algorithm, &self.key_der) + .expect("Re-parsing key should not fail"); + Box::new(EcdsaSigningKey { + key_pair, + signing_algorithm: self.signing_algorithm, + key_der: self.key_der.clone(), + }) + } } /// Key provider implementation. @@ -68,12 +79,14 @@ impl KeyProvider for AwsLcKeyProvider { return Ok(Box::new(EcdsaSigningKey { key_pair, signing_algorithm: &ECDSA_P256_SHA256_ASN1_SIGNING, + key_der: key_der.to_vec(), })); } if let Ok(key_pair) = EcdsaKeyPair::from_pkcs8(&ECDSA_P384_SHA384_ASN1_SIGNING, key_der) { return Ok(Box::new(EcdsaSigningKey { key_pair, signing_algorithm: &ECDSA_P384_SHA384_ASN1_SIGNING, + key_der: key_der.to_vec(), })); } @@ -124,6 +137,7 @@ impl KeyProvider for AwsLcKeyProvider { return Ok(Box::new(EcdsaSigningKey { key_pair, signing_algorithm: &ECDSA_P256_SHA256_ASN1_SIGNING, + key_der: pkcs8_der.clone(), })); } } @@ -136,6 +150,7 @@ impl KeyProvider for AwsLcKeyProvider { return Ok(Box::new(EcdsaSigningKey { key_pair, signing_algorithm: &ECDSA_P384_SHA384_ASN1_SIGNING, + key_der: pkcs8_der.clone(), })); } } diff --git a/src/crypto/provider.rs b/src/crypto/provider.rs index 5def5fd..8210d20 100644 --- a/src/crypto/provider.rs +++ b/src/crypto/provider.rs @@ -209,6 +209,10 @@ pub trait SigningKey: CryptoSafe { /// Default hash algorithm for this key. fn hash_algorithm(&self) -> HashAlgorithm; + + /// Clone this signing key into a new boxed instance. + /// Used to support cloning `DtlsCertificate` when using external signing keys. + fn clone_box(&self) -> Box; } /// Active key exchange instance (ephemeral keypair for one handshake). diff --git a/src/crypto/rust_crypto/sign.rs b/src/crypto/rust_crypto/sign.rs index 15cf044..40714e7 100644 --- a/src/crypto/rust_crypto/sign.rs +++ b/src/crypto/rust_crypto/sign.rs @@ -83,6 +83,13 @@ impl SigningKeyTrait for EcdsaSigningKey { EcdsaSigningKey::P384(_) => HashAlgorithm::SHA384, } } + + fn clone_box(&self) -> Box { + match self { + EcdsaSigningKey::P256(key) => Box::new(EcdsaSigningKey::P256(key.clone())), + EcdsaSigningKey::P384(key) => Box::new(EcdsaSigningKey::P384(key.clone())), + } + } } /// Key provider implementation. diff --git a/src/dtls12/context.rs b/src/dtls12/context.rs index 59a7fcd..be7e5b3 100644 --- a/src/dtls12/context.rs +++ b/src/dtls12/context.rs @@ -73,24 +73,28 @@ impl CryptoContext { /// Create a new crypto context pub fn new( certificate: Vec, - private_key_bytes: Vec, + private_key: crate::DtlsCertificatePrivateKey, config: Arc, ) -> Self { - // Validate that we have a certificate and private key + // Validate that we have a certificate if certificate.is_empty() { panic!("Client certificate cannot be empty"); } - if private_key_bytes.is_empty() { - panic!("Client private key cannot be empty"); - } - - // Parse the private key using the provider - let private_key = config - .crypto_provider() - .key_provider - .load_private_key(&private_key_bytes) - .expect("Failed to parse client private key"); + // Load or use the private key + let private_key = match &private_key { + crate::DtlsCertificatePrivateKey::Pkcs8(key_der) => { + if key_der.is_empty() { + panic!("Client private key cannot be empty"); + } + config + .crypto_provider() + .key_provider + .load_private_key(key_der) + .expect("Failed to parse client private key") + } + crate::DtlsCertificatePrivateKey::SigningKey(key) => (**key).clone_box(), + }; CryptoContext { config, diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index fb3074a..25f45a7 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -28,7 +28,7 @@ use crate::dtls13::message::Sequence; use crate::timer::ExponentialBackoff; use crate::types::{HashAlgorithm, Random}; use crate::window::ReplayWindow; -use crate::{Config, DtlsCertificate, Error, Output, SeededRng}; +use crate::{Config, DtlsCertificate, DtlsCertificatePrivateKey, Error, Output, SeededRng}; const MAX_DEFRAGMENT_PACKETS: usize = 50; @@ -196,11 +196,14 @@ impl Engine { let flight_backoff = ExponentialBackoff::new(config.flight_start_rto(), config.flight_retries(), &mut rng); - let signing_key = config - .crypto_provider() - .key_provider - .load_private_key(&certificate.private_key) - .expect("Failed to load private key"); + let signing_key = match &certificate.private_key { + DtlsCertificatePrivateKey::Pkcs8(key_der) => config + .crypto_provider() + .key_provider + .load_private_key(key_der) + .expect("Failed to load private key"), + DtlsCertificatePrivateKey::SigningKey(key) => (**key).clone_box(), + }; let aead_encryption_threshold = jittered_aead_threshold(config.aead_encryption_limit(), &mut rng); diff --git a/src/lib.rs b/src/lib.rs index ecbe4df..cb802a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -199,7 +199,7 @@ pub mod certificate; pub mod crypto; -pub use crypto::{KeyingMaterial, SrtpProfile}; +pub use crypto::{KeyingMaterial, SigningKey, SrtpProfile}; mod timer; @@ -211,15 +211,99 @@ pub(crate) use rng::SeededRng; pub struct DtlsCertificate { /// Certificate in DER format. pub certificate: Vec, - /// Private key in DER format. - pub private_key: Vec, + /// Private key (either PKCS8 DER bytes or a signing key trait object). + pub private_key: DtlsCertificatePrivateKey, +} + +/// Private key representation for DTLS certificates. +/// +/// Supports either PKCS8 DER-encoded private key bytes or a pre-loaded +/// signing key trait object. The latter is useful for hardware security +/// modules or keystores where the private key material cannot be exported. +/// +/// When using `SigningKey`, the Arc wrapper enables drop tracking: when the +/// last reference is dropped, implementers can clean up native crypto resources +/// via the SigningKey's Drop implementation. +/// +/// # Examples +/// +/// Using PKCS8 DER bytes (the default): +/// ```ignore +/// let cert = DtlsCertificate { +/// certificate: cert_der, +/// private_key: DtlsCertificatePrivateKey::Pkcs8(key_der), +/// }; +/// ``` +/// +/// Using a custom signing key (e.g., from a hardware security module): +/// ```ignore +/// struct MyHsmSigningKey { /* ... */ } +/// +/// impl SigningKey for MyHsmSigningKey { +/// fn sign(&mut self, data: &[u8], out: &mut Buf) -> Result<(), String> { +/// // Call HSM to sign data +/// todo!() +/// } +/// +/// fn algorithm(&self) -> SignatureAlgorithm { +/// SignatureAlgorithm::ECDSA +/// } +/// +/// fn hash_algorithm(&self) -> HashAlgorithm { +/// HashAlgorithm::SHA256 +/// } +/// +/// fn clone_box(&self) -> Box { +/// Box::new(MyHsmSigningKey { /* clone fields */ }) +/// } +/// } +/// +/// impl Drop for MyHsmSigningKey { +/// fn drop(&mut self) { +/// // Clean up HSM resources when the last reference is dropped +/// } +/// } +/// +/// let signing_key = Arc::new(Box::new(MyHsmSigningKey { /* ... */ }) as Box); +/// let cert = DtlsCertificate { +/// certificate: cert_der, +/// private_key: DtlsCertificatePrivateKey::SigningKey(signing_key), +/// }; +/// ``` +pub enum DtlsCertificatePrivateKey { + /// Private key in PKCS8 DER format. + Pkcs8(Vec), + /// Pre-loaded signing key. Wrapped in Arc for cloning and drop tracking. + SigningKey(Arc>), +} + +impl Clone for DtlsCertificatePrivateKey { + fn clone(&self) -> Self { + match self { + Self::Pkcs8(bytes) => Self::Pkcs8(bytes.clone()), + Self::SigningKey(key) => Self::SigningKey(Arc::clone(key)), + } + } +} + +impl fmt::Debug for DtlsCertificatePrivateKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Pkcs8(bytes) => write!(f, "Pkcs8({} bytes)", bytes.len()), + Self::SigningKey(_) => write!(f, "SigningKey"), + } + } } impl fmt::Debug for DtlsCertificate { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let private_key_desc = match &self.private_key { + DtlsCertificatePrivateKey::Pkcs8(bytes) => format!("Pkcs8({})", bytes.len()), + DtlsCertificatePrivateKey::SigningKey(_) => "SigningKey".to_string(), + }; f.debug_struct("DtlsCertificate") .field("certificate", &self.certificate.len()) - .field("private_key", &self.private_key.len()) + .field("private_key", &private_key_desc) .finish() } }