@@ -423,9 +423,10 @@ static __poll_t tls_sk_poll(struct file *file, struct socket *sock,
423423 ctx = tls_sw_ctx_rx (tls_ctx );
424424 psock = sk_psock_get (sk );
425425
426- if (skb_queue_empty_lockless (& ctx -> rx_list ) &&
427- !tls_strp_msg_ready (ctx ) &&
428- sk_psock_queue_empty (psock ))
426+ if ((skb_queue_empty_lockless (& ctx -> rx_list ) &&
427+ !tls_strp_msg_ready (ctx ) &&
428+ sk_psock_queue_empty (psock )) ||
429+ READ_ONCE (ctx -> key_update_pending ))
429430 mask &= ~(EPOLLIN | EPOLLRDNORM );
430431
431432 if (psock )
@@ -612,11 +613,13 @@ static int validate_crypto_info(const struct tls_crypto_info *crypto_info,
612613static int do_tls_setsockopt_conf (struct sock * sk , sockptr_t optval ,
613614 unsigned int optlen , int tx )
614615{
615- struct tls_crypto_info * crypto_info ;
616- struct tls_crypto_info * alt_crypto_info ;
616+ struct tls_crypto_info * crypto_info , * alt_crypto_info ;
617+ struct tls_crypto_info * old_crypto_info = NULL ;
617618 struct tls_context * ctx = tls_get_ctx (sk );
618619 const struct tls_cipher_desc * cipher_desc ;
619620 union tls_crypto_context * crypto_ctx ;
621+ union tls_crypto_context tmp = {};
622+ bool update = false;
620623 int rc = 0 ;
621624 int conf ;
622625
@@ -633,17 +636,36 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
633636
634637 crypto_info = & crypto_ctx -> info ;
635638
636- /* Currently we don't support set crypto info more than one time */
637- if (TLS_CRYPTO_INFO_READY (crypto_info ))
638- return - EBUSY ;
639+ if (TLS_CRYPTO_INFO_READY (crypto_info )) {
640+ /* Currently we only support setting crypto info more
641+ * than one time for TLS 1.3
642+ */
643+ if (crypto_info -> version != TLS_1_3_VERSION ) {
644+ TLS_INC_STATS (sock_net (sk ), tx ? LINUX_MIB_TLSTXREKEYERROR
645+ : LINUX_MIB_TLSRXREKEYERROR );
646+ return - EBUSY ;
647+ }
648+
649+ update = true;
650+ old_crypto_info = crypto_info ;
651+ crypto_info = & tmp .info ;
652+ crypto_ctx = & tmp ;
653+ }
639654
640655 rc = copy_from_sockptr (crypto_info , optval , sizeof (* crypto_info ));
641656 if (rc ) {
642657 rc = - EFAULT ;
643658 goto err_crypto_info ;
644659 }
645660
646- rc = validate_crypto_info (crypto_info , alt_crypto_info );
661+ if (update ) {
662+ /* Ensure that TLS version and ciphers are not modified */
663+ if (crypto_info -> version != old_crypto_info -> version ||
664+ crypto_info -> cipher_type != old_crypto_info -> cipher_type )
665+ rc = - EINVAL ;
666+ } else {
667+ rc = validate_crypto_info (crypto_info , alt_crypto_info );
668+ }
647669 if (rc )
648670 goto err_crypto_info ;
649671
@@ -673,11 +695,17 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
673695 TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSTXDEVICE );
674696 TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSCURRTXDEVICE );
675697 } else {
676- rc = tls_set_sw_offload (sk , 1 );
698+ rc = tls_set_sw_offload (sk , 1 ,
699+ update ? crypto_info : NULL );
677700 if (rc )
678701 goto err_crypto_info ;
679- TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSTXSW );
680- TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSCURRTXSW );
702+
703+ if (update ) {
704+ TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSTXREKEYOK );
705+ } else {
706+ TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSTXSW );
707+ TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSCURRTXSW );
708+ }
681709 conf = TLS_SW ;
682710 }
683711 } else {
@@ -687,21 +715,32 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
687715 TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSRXDEVICE );
688716 TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSCURRRXDEVICE );
689717 } else {
690- rc = tls_set_sw_offload (sk , 0 );
718+ rc = tls_set_sw_offload (sk , 0 ,
719+ update ? crypto_info : NULL );
691720 if (rc )
692721 goto err_crypto_info ;
693- TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSRXSW );
694- TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSCURRRXSW );
722+
723+ if (update ) {
724+ TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSRXREKEYOK );
725+ } else {
726+ TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSRXSW );
727+ TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSCURRRXSW );
728+ }
695729 conf = TLS_SW ;
696730 }
697- tls_sw_strparser_arm (sk , ctx );
731+ if (!update )
732+ tls_sw_strparser_arm (sk , ctx );
698733 }
699734
700735 if (tx )
701736 ctx -> tx_conf = conf ;
702737 else
703738 ctx -> rx_conf = conf ;
704739 update_sk_prot (sk , ctx );
740+
741+ if (update )
742+ return 0 ;
743+
705744 if (tx ) {
706745 ctx -> sk_write_space = sk -> sk_write_space ;
707746 sk -> sk_write_space = tls_write_space ;
@@ -713,6 +752,10 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
713752 return 0 ;
714753
715754err_crypto_info :
755+ if (update ) {
756+ TLS_INC_STATS (sock_net (sk ), tx ? LINUX_MIB_TLSTXREKEYERROR
757+ : LINUX_MIB_TLSRXREKEYERROR );
758+ }
716759 memzero_explicit (crypto_ctx , sizeof (* crypto_ctx ));
717760 return rc ;
718761}
@@ -809,6 +852,11 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
809852 return do_tls_setsockopt (sk , optname , optval , optlen );
810853}
811854
855+ static int tls_disconnect (struct sock * sk , int flags )
856+ {
857+ return - EOPNOTSUPP ;
858+ }
859+
812860struct tls_context * tls_ctx_create (struct sock * sk )
813861{
814862 struct inet_connection_sock * icsk = inet_csk (sk );
@@ -904,6 +952,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
904952 prot [TLS_BASE ][TLS_BASE ] = * base ;
905953 prot [TLS_BASE ][TLS_BASE ].setsockopt = tls_setsockopt ;
906954 prot [TLS_BASE ][TLS_BASE ].getsockopt = tls_getsockopt ;
955+ prot [TLS_BASE ][TLS_BASE ].disconnect = tls_disconnect ;
907956 prot [TLS_BASE ][TLS_BASE ].close = tls_sk_proto_close ;
908957
909958 prot [TLS_SW ][TLS_BASE ] = prot [TLS_BASE ][TLS_BASE ];
0 commit comments