1+ extern crate linked_hash_set;
2+ extern crate once_cell;
13extern crate openssl;
24extern crate openssl_probe;
35
6+ use self :: linked_hash_set:: LinkedHashSet ;
7+ use self :: once_cell:: sync:: OnceCell ;
48use self :: openssl:: error:: ErrorStack ;
9+ use self :: openssl:: ex_data:: Index ;
510use self :: openssl:: hash:: MessageDigest ;
611use self :: openssl:: nid:: Nid ;
712use self :: openssl:: pkcs12:: Pkcs12 ;
813use self :: openssl:: pkey:: PKey ;
914use self :: openssl:: ssl:: {
10- self , MidHandshakeSslStream , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
11- SslVerifyMode ,
15+ self , MidHandshakeSslStream , Ssl , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
16+ SslSession , SslSessionCacheMode , SslSessionRef , SslVerifyMode ,
1217} ;
1318use self :: openssl:: x509:: { X509 , store:: X509StoreBuilder , X509VerifyResult } ;
19+ use std:: borrow:: Borrow ;
20+ use std:: collections:: hash_map:: { Entry , HashMap } ;
1421use std:: error;
1522use std:: fmt;
23+ use std:: hash:: { Hash , Hasher } ;
1624use std:: io;
17- use std:: sync:: Once ;
25+ use std:: sync:: { Arc , Mutex , Once } ;
1826
1927use { Protocol , TlsAcceptorBuilder , TlsConnectorBuilder } ;
2028use self :: openssl:: pkey:: Private ;
@@ -248,6 +256,8 @@ pub struct TlsConnector {
248256 use_sni : bool ,
249257 accept_invalid_hostnames : bool ,
250258 accept_invalid_certs : bool ,
259+ session_tickets_enabled : bool ,
260+ session_cache : Arc < Mutex < SessionCache > > ,
251261}
252262
253263impl TlsConnector {
@@ -277,11 +287,37 @@ impl TlsConnector {
277287 #[ cfg( target_os = "android" ) ]
278288 load_android_root_certs ( & mut connector) ?;
279289
290+ let session_cache = Arc :: new ( Mutex :: new ( SessionCache :: new ( ) ) ) ;
291+ if builder. session_tickets_enabled {
292+ connector. set_session_cache_mode ( SslSessionCacheMode :: CLIENT ) ;
293+
294+ connector. set_new_session_callback ( {
295+ let session_cache = session_cache. clone ( ) ;
296+ move |ssl, session| {
297+ if let Some ( key) = key_index ( ) . ok ( ) . and_then ( |idx| ssl. ex_data ( idx) ) {
298+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
299+ session_cache. insert ( key. clone ( ) , session) ;
300+ }
301+ }
302+ }
303+ } ) ;
304+ connector. set_remove_session_callback ( {
305+ let session_cache = session_cache. clone ( ) ;
306+ move |_, session| {
307+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
308+ session_cache. remove ( session) ;
309+ }
310+ }
311+ } ) ;
312+ }
313+
280314 Ok ( TlsConnector {
281315 connector : connector. build ( ) ,
282316 use_sni : builder. use_sni ,
283317 accept_invalid_hostnames : builder. accept_invalid_hostnames ,
284318 accept_invalid_certs : builder. accept_invalid_certs ,
319+ session_tickets_enabled : builder. session_tickets_enabled ,
320+ session_cache,
285321 } )
286322 }
287323
@@ -297,6 +333,23 @@ impl TlsConnector {
297333 if self . accept_invalid_certs {
298334 ssl. set_verify ( SslVerifyMode :: NONE ) ;
299335 }
336+ if self . session_tickets_enabled {
337+ let key = SessionKey {
338+ host : domain. to_string ( ) ,
339+ } ;
340+
341+ if let Ok ( mut session_cache) = self . session_cache . lock ( ) {
342+ if let Some ( session) = session_cache. get ( & key) {
343+ // Note: the `unsafe`-ty here is because the `session` is required to come from the
344+ // same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
345+ // pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
346+ unsafe { ssl. set_session ( & session) ? } ;
347+ }
348+ }
349+
350+ let idx = key_index ( ) ?;
351+ ssl. set_ex_data ( idx, key) ;
352+ }
300353
301354 let s = ssl. connect ( domain, stream) ?;
302355 Ok ( TlsStream ( s) )
@@ -412,3 +465,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
412465 self . 0 . flush ( )
413466 }
414467}
468+
469+ fn key_index ( ) -> Result < Index < Ssl , SessionKey > , ErrorStack > {
470+ static IDX : OnceCell < Index < Ssl , SessionKey > > = OnceCell :: new ( ) ;
471+ IDX . get_or_try_init ( || Ssl :: new_ex_index ( ) ) . map ( |v| * v)
472+ }
473+
474+ #[ derive( Hash , PartialEq , Eq , Clone ) ]
475+ pub struct SessionKey {
476+ pub host : String ,
477+ }
478+
479+ #[ derive( Clone ) ]
480+ struct HashSession ( SslSession ) ;
481+
482+ impl PartialEq for HashSession {
483+ fn eq ( & self , other : & HashSession ) -> bool {
484+ self . 0 . id ( ) == other. 0 . id ( )
485+ }
486+ }
487+
488+ impl Eq for HashSession { }
489+
490+ impl Hash for HashSession {
491+ fn hash < H > ( & self , state : & mut H )
492+ where
493+ H : Hasher ,
494+ {
495+ self . 0 . id ( ) . hash ( state) ;
496+ }
497+ }
498+
499+ impl Borrow < [ u8 ] > for HashSession {
500+ fn borrow ( & self ) -> & [ u8 ] {
501+ self . 0 . id ( )
502+ }
503+ }
504+
505+ pub struct SessionCache {
506+ sessions : HashMap < SessionKey , LinkedHashSet < HashSession > > ,
507+ reverse : HashMap < HashSession , SessionKey > ,
508+ }
509+
510+ impl SessionCache {
511+ pub fn new ( ) -> SessionCache {
512+ SessionCache {
513+ sessions : HashMap :: new ( ) ,
514+ reverse : HashMap :: new ( ) ,
515+ }
516+ }
517+
518+ pub fn insert ( & mut self , key : SessionKey , session : SslSession ) {
519+ let session = HashSession ( session) ;
520+
521+ self . sessions
522+ . entry ( key. clone ( ) )
523+ . or_insert_with ( LinkedHashSet :: new)
524+ . insert ( session. clone ( ) ) ;
525+ self . reverse . insert ( session. clone ( ) , key) ;
526+ }
527+
528+ pub fn get ( & mut self , key : & SessionKey ) -> Option < SslSession > {
529+ let session = {
530+ let sessions = self . sessions . get_mut ( key) ?;
531+ sessions. front ( ) . cloned ( ) ?. 0
532+ } ;
533+
534+ #[ cfg( ossl111) ]
535+ {
536+ use self :: openssl:: ssl:: SslVersion ;
537+
538+ // https://tools.ietf.org/html/rfc8446#appendix-C.4
539+ // OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
540+ // that concurrent handshakes don't end up with the same session.
541+ if session. protocol_version ( ) == SslVersion :: TLS1_3 {
542+ self . remove ( & session) ;
543+ }
544+ }
545+
546+ Some ( session)
547+ }
548+
549+ pub fn remove ( & mut self , session : & SslSessionRef ) {
550+ let key = match self . reverse . remove ( session. id ( ) ) {
551+ Some ( key) => key,
552+ None => return ,
553+ } ;
554+
555+ if let Entry :: Occupied ( mut sessions) = self . sessions . entry ( key) {
556+ sessions. get_mut ( ) . remove ( session. id ( ) ) ;
557+ if sessions. get ( ) . is_empty ( ) {
558+ sessions. remove ( ) ;
559+ }
560+ }
561+ }
562+ }
563+
564+ #[ cfg( test) ]
565+ mod tests {
566+ use std:: io:: { Read , Write } ;
567+ use std:: net:: TcpStream ;
568+
569+ use crate :: TlsConnector ;
570+
571+ fn connect_and_assert ( tls : & TlsConnector , domain : & str , port : u16 , should_resume : bool ) {
572+ let s = TcpStream :: connect ( ( domain, port) ) . unwrap ( ) ;
573+ let mut stream = tls. connect ( domain, s) . unwrap ( ) ;
574+
575+ // Must write to the stream, as OpenSSL doesn't appear to call the
576+ // session callback until we do.
577+ stream. write_all ( b"GET / HTTP/1.0\r \n \r \n " ) . unwrap ( ) ;
578+ let mut result = vec ! [ ] ;
579+ stream. read_to_end ( & mut result) . unwrap ( ) ;
580+
581+ assert_eq ! ( ( stream. 0 ) . 0 . ssl( ) . session_reused( ) , should_resume) ;
582+
583+ // Must shut down properly, or OpenSSL will invalidate the session.
584+ stream. shutdown ( ) . unwrap ( ) ;
585+ }
586+
587+ #[ test]
588+ fn connect_no_session_ticket_resumption ( ) {
589+ let tls = TlsConnector :: new ( ) . unwrap ( ) ;
590+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
591+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
592+ }
593+
594+ #[ test]
595+ fn connect_session_ticket_resumption ( ) {
596+ let mut builder = TlsConnector :: builder ( ) ;
597+ builder. session_tickets_enabled ( true ) ;
598+ let tls = builder. build ( ) . unwrap ( ) ;
599+
600+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
601+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
602+ }
603+
604+ #[ test]
605+ fn connect_session_ticket_resumption_two_sites ( ) {
606+ let mut builder = TlsConnector :: builder ( ) ;
607+ builder. session_tickets_enabled ( true ) ;
608+ let tls = builder. build ( ) . unwrap ( ) ;
609+
610+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
611+ connect_and_assert ( & tls, "mozilla.org" , 443 , false ) ;
612+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
613+ connect_and_assert ( & tls, "mozilla.org" , 443 , true ) ;
614+ }
615+ }
0 commit comments