@@ -11,15 +11,39 @@ use crate::types::{Dtls13CipherSuite, NamedGroup};
1111
1212/// Callback for resolving PSK identities to shared secrets.
1313///
14- /// Implement this trait and provide it via `ConfigBuilder::with_psk_resolver`
15- /// to enable PSK cipher suites.
14+ /// Implement this trait and provide it via [ `ConfigBuilder::with_psk_client`]
15+ /// or [`ConfigBuilder::with_psk_server`] to enable PSK cipher suites.
1616pub trait PskResolver : Send + Sync + UnwindSafe + RefUnwindSafe {
1717 /// Look up a pre-shared key by the peer's identity.
1818 ///
1919 /// Returns the shared secret bytes, or `None` if the identity is unknown.
2020 fn resolve ( & self , identity : & [ u8 ] ) -> Option < Vec < u8 > > ;
2121}
2222
23+ /// PSK configuration for a DTLS endpoint.
24+ ///
25+ /// Use [`Psk::Client`] for endpoints that initiate PSK handshakes (send identity),
26+ /// and [`Psk::Server`] for endpoints that resolve incoming identities.
27+ #[ derive( Clone ) ]
28+ pub enum Psk {
29+ /// Client-side PSK: sends `identity` during handshake, uses `resolver`
30+ /// to look up the shared secret.
31+ Client {
32+ /// The identity to send to the server.
33+ identity : Vec < u8 > ,
34+ /// Resolver for looking up shared secrets.
35+ resolver : Arc < dyn PskResolver > ,
36+ } ,
37+ /// Server-side PSK: optionally sends a `hint` to help the client choose
38+ /// an identity, uses `resolver` to look up secrets by client identity.
39+ Server {
40+ /// Optional hint sent to the client in ServerKeyExchange.
41+ hint : Option < Vec < u8 > > ,
42+ /// Resolver for looking up shared secrets.
43+ resolver : Arc < dyn PskResolver > ,
44+ } ,
45+ }
46+
2347#[ cfg( feature = "aws-lc-rs" ) ]
2448use crate :: crypto:: aws_lc_rs;
2549
@@ -45,36 +69,7 @@ pub struct Config {
4569 dtls12_cipher_suites : Option < Vec < Dtls12CipherSuite > > ,
4670 dtls13_cipher_suites : Option < Vec < Dtls13CipherSuite > > ,
4771 kx_groups : Option < Vec < NamedGroup > > ,
48- psk_identity : Option < Vec < u8 > > ,
49- psk_identity_hint : Option < Vec < u8 > > ,
50- psk_resolver : Option < Arc < dyn PskResolver > > ,
51- }
52-
53- impl fmt:: Debug for Config {
54- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
55- f. debug_struct ( "Config" )
56- . field ( "mtu" , & self . mtu )
57- . field ( "max_queue_rx" , & self . max_queue_rx )
58- . field ( "max_queue_tx" , & self . max_queue_tx )
59- . field (
60- "require_client_certificate" ,
61- & self . require_client_certificate ,
62- )
63- . field ( "use_server_cookie" , & self . use_server_cookie )
64- . field ( "flight_start_rto" , & self . flight_start_rto )
65- . field ( "flight_retries" , & self . flight_retries )
66- . field ( "handshake_timeout" , & self . handshake_timeout )
67- . field ( "crypto_provider" , & self . crypto_provider )
68- . field ( "rng_seed" , & self . rng_seed )
69- . field ( "aead_encryption_limit" , & self . aead_encryption_limit )
70- . field ( "dtls12_cipher_suites" , & self . dtls12_cipher_suites )
71- . field ( "dtls13_cipher_suites" , & self . dtls13_cipher_suites )
72- . field ( "kx_groups" , & self . kx_groups )
73- . field ( "psk_identity" , & self . psk_identity )
74- . field ( "psk_identity_hint" , & self . psk_identity_hint )
75- . field ( "psk_resolver" , & self . psk_resolver . as_ref ( ) . map ( |_| "..." ) )
76- . finish ( )
77- }
72+ psk : Option < Psk > ,
7873}
7974
8075impl Config {
@@ -95,9 +90,7 @@ impl Config {
9590 dtls12_cipher_suites : None ,
9691 dtls13_cipher_suites : None ,
9792 kx_groups : None ,
98- psk_identity : None ,
99- psk_identity_hint : None ,
100- psk_resolver : None ,
93+ psk : None ,
10194 }
10295 }
10396
@@ -195,19 +188,35 @@ impl Config {
195188 self . aead_encryption_limit
196189 }
197190
191+ /// PSK configuration, if any.
192+ pub fn psk ( & self ) -> Option < & Psk > {
193+ self . psk . as_ref ( )
194+ }
195+
198196 /// PSK identity for the client to send during handshake.
199197 pub fn psk_identity ( & self ) -> Option < & [ u8 ] > {
200- self . psk_identity . as_deref ( )
198+ match & self . psk {
199+ Some ( Psk :: Client { identity, .. } ) => Some ( identity) ,
200+ _ => None ,
201+ }
201202 }
202203
203204 /// PSK identity hint for the server to send during handshake.
204205 pub fn psk_identity_hint ( & self ) -> Option < & [ u8 ] > {
205- self . psk_identity_hint . as_deref ( )
206+ match & self . psk {
207+ Some ( Psk :: Server { hint, .. } ) => hint. as_deref ( ) ,
208+ _ => None ,
209+ }
206210 }
207211
208212 /// PSK resolver for looking up shared secrets by identity.
209213 pub fn psk_resolver ( & self ) -> Option < & dyn PskResolver > {
210- self . psk_resolver . as_deref ( )
214+ match & self . psk {
215+ Some ( Psk :: Client { resolver, .. } | Psk :: Server { resolver, .. } ) => {
216+ Some ( resolver. as_ref ( ) )
217+ }
218+ None => None ,
219+ }
211220 }
212221
213222 /// Allowed DTLS 1.2 cipher suites, filtered by the config's allow-list.
@@ -223,7 +232,7 @@ impl Config {
223232 & self ,
224233 ) -> impl Iterator < Item = & ' static dyn SupportedDtls12CipherSuite > + ' _ {
225234 let filter = self . dtls12_cipher_suites . as_ref ( ) ;
226- let has_psk = self . psk_resolver . is_some ( ) ;
235+ let has_psk = self . psk . is_some ( ) ;
227236 self . crypto_provider
228237 . supported_cipher_suites ( )
229238 . filter ( move |cs| match filter {
@@ -284,36 +293,7 @@ pub struct ConfigBuilder {
284293 dtls12_cipher_suites : Option < Vec < Dtls12CipherSuite > > ,
285294 dtls13_cipher_suites : Option < Vec < Dtls13CipherSuite > > ,
286295 kx_groups : Option < Vec < NamedGroup > > ,
287- psk_identity : Option < Vec < u8 > > ,
288- psk_identity_hint : Option < Vec < u8 > > ,
289- psk_resolver : Option < Arc < dyn PskResolver > > ,
290- }
291-
292- impl fmt:: Debug for ConfigBuilder {
293- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
294- f. debug_struct ( "ConfigBuilder" )
295- . field ( "mtu" , & self . mtu )
296- . field ( "max_queue_rx" , & self . max_queue_rx )
297- . field ( "max_queue_tx" , & self . max_queue_tx )
298- . field (
299- "require_client_certificate" ,
300- & self . require_client_certificate ,
301- )
302- . field ( "use_server_cookie" , & self . use_server_cookie )
303- . field ( "flight_start_rto" , & self . flight_start_rto )
304- . field ( "flight_retries" , & self . flight_retries )
305- . field ( "handshake_timeout" , & self . handshake_timeout )
306- . field ( "crypto_provider" , & self . crypto_provider )
307- . field ( "rng_seed" , & self . rng_seed )
308- . field ( "aead_encryption_limit" , & self . aead_encryption_limit )
309- . field ( "dtls12_cipher_suites" , & self . dtls12_cipher_suites )
310- . field ( "dtls13_cipher_suites" , & self . dtls13_cipher_suites )
311- . field ( "kx_groups" , & self . kx_groups )
312- . field ( "psk_identity" , & self . psk_identity )
313- . field ( "psk_identity_hint" , & self . psk_identity_hint )
314- . field ( "psk_resolver" , & self . psk_resolver . as_ref ( ) . map ( |_| "..." ) )
315- . finish ( )
316- }
296+ psk : Option < Psk > ,
317297}
318298
319299impl ConfigBuilder {
@@ -457,21 +437,25 @@ impl ConfigBuilder {
457437 self
458438 }
459439
460- /// Set the PSK identity for the client to send during handshake.
461- pub fn with_psk_identity ( mut self , identity : Vec < u8 > ) -> Self {
462- self . psk_identity = Some ( identity) ;
463- self
464- }
465-
466- /// Set the PSK identity hint for the server to send during handshake.
467- pub fn with_psk_identity_hint ( mut self , hint : Vec < u8 > ) -> Self {
468- self . psk_identity_hint = Some ( hint) ;
440+ /// Configure PSK for a client endpoint.
441+ ///
442+ /// The `identity` is sent to the server during the handshake.
443+ /// The `resolver` looks up the shared secret by identity.
444+ pub fn with_psk_client ( mut self , identity : Vec < u8 > , resolver : Arc < dyn PskResolver > ) -> Self {
445+ self . psk = Some ( Psk :: Client { identity, resolver } ) ;
469446 self
470447 }
471448
472- /// Set the PSK resolver for looking up shared secrets by identity.
473- pub fn with_psk_resolver ( mut self , resolver : Arc < dyn PskResolver > ) -> Self {
474- self . psk_resolver = Some ( resolver) ;
449+ /// Configure PSK for a server endpoint.
450+ ///
451+ /// The optional `hint` is sent to the client in ServerKeyExchange.
452+ /// The `resolver` looks up the shared secret by client identity.
453+ pub fn with_psk_server (
454+ mut self ,
455+ hint : Option < Vec < u8 > > ,
456+ resolver : Arc < dyn PskResolver > ,
457+ ) -> Self {
458+ self . psk = Some ( Psk :: Server { hint, resolver } ) ;
475459 self
476460 }
477461
@@ -607,9 +591,7 @@ impl ConfigBuilder {
607591 dtls12_cipher_suites : self . dtls12_cipher_suites ,
608592 dtls13_cipher_suites : self . dtls13_cipher_suites ,
609593 kx_groups : self . kx_groups ,
610- psk_identity : self . psk_identity ,
611- psk_identity_hint : self . psk_identity_hint ,
612- psk_resolver : self . psk_resolver ,
594+ psk : self . psk ,
613595 } )
614596 }
615597}
@@ -622,6 +604,73 @@ impl Default for Config {
622604 }
623605}
624606
607+ impl fmt:: Debug for Psk {
608+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
609+ match self {
610+ Psk :: Client { identity, .. } => f
611+ . debug_struct ( "Psk::Client" )
612+ . field ( "identity" , & identity)
613+ . field ( "resolver" , & "..." )
614+ . finish ( ) ,
615+ Psk :: Server { hint, .. } => f
616+ . debug_struct ( "Psk::Server" )
617+ . field ( "hint" , & hint)
618+ . field ( "resolver" , & "..." )
619+ . finish ( ) ,
620+ }
621+ }
622+ }
623+
624+ impl fmt:: Debug for Config {
625+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
626+ f. debug_struct ( "Config" )
627+ . field ( "mtu" , & self . mtu )
628+ . field ( "max_queue_rx" , & self . max_queue_rx )
629+ . field ( "max_queue_tx" , & self . max_queue_tx )
630+ . field (
631+ "require_client_certificate" ,
632+ & self . require_client_certificate ,
633+ )
634+ . field ( "use_server_cookie" , & self . use_server_cookie )
635+ . field ( "flight_start_rto" , & self . flight_start_rto )
636+ . field ( "flight_retries" , & self . flight_retries )
637+ . field ( "handshake_timeout" , & self . handshake_timeout )
638+ . field ( "crypto_provider" , & self . crypto_provider )
639+ . field ( "rng_seed" , & self . rng_seed )
640+ . field ( "aead_encryption_limit" , & self . aead_encryption_limit )
641+ . field ( "dtls12_cipher_suites" , & self . dtls12_cipher_suites )
642+ . field ( "dtls13_cipher_suites" , & self . dtls13_cipher_suites )
643+ . field ( "kx_groups" , & self . kx_groups )
644+ . field ( "psk" , & self . psk )
645+ . finish ( )
646+ }
647+ }
648+
649+ impl fmt:: Debug for ConfigBuilder {
650+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
651+ f. debug_struct ( "ConfigBuilder" )
652+ . field ( "mtu" , & self . mtu )
653+ . field ( "max_queue_rx" , & self . max_queue_rx )
654+ . field ( "max_queue_tx" , & self . max_queue_tx )
655+ . field (
656+ "require_client_certificate" ,
657+ & self . require_client_certificate ,
658+ )
659+ . field ( "use_server_cookie" , & self . use_server_cookie )
660+ . field ( "flight_start_rto" , & self . flight_start_rto )
661+ . field ( "flight_retries" , & self . flight_retries )
662+ . field ( "handshake_timeout" , & self . handshake_timeout )
663+ . field ( "crypto_provider" , & self . crypto_provider )
664+ . field ( "rng_seed" , & self . rng_seed )
665+ . field ( "aead_encryption_limit" , & self . aead_encryption_limit )
666+ . field ( "dtls12_cipher_suites" , & self . dtls12_cipher_suites )
667+ . field ( "dtls13_cipher_suites" , & self . dtls13_cipher_suites )
668+ . field ( "kx_groups" , & self . kx_groups )
669+ . field ( "psk" , & self . psk )
670+ . finish ( )
671+ }
672+ }
673+
625674#[ cfg( test) ]
626675mod tests {
627676 use super :: * ;
@@ -823,7 +872,7 @@ mod tests {
823872 }
824873
825874 let config = Config :: builder ( )
826- . with_psk_resolver ( Arc :: new ( DummyResolver ) )
875+ . with_psk_server ( None , Arc :: new ( DummyResolver ) )
827876 . build ( )
828877 . expect ( "config with PSK resolver should build" ) ;
829878 assert ! (
0 commit comments