1111#define pr_fmt (fmt ) KBUILD_MODNAME ": " fmt
1212
1313#include <linux/in.h>
14+ #include <linux/in6.h>
1415#include <linux/module.h>
1516#include <linux/net.h>
1617#include <linux/ipv6.h>
@@ -191,12 +192,13 @@ static void p9_conn_cancel(struct p9_conn *m, int err)
191192
192193 spin_lock (& m -> req_lock );
193194
194- if (m -> err ) {
195+ if (READ_ONCE ( m -> err ) ) {
195196 spin_unlock (& m -> req_lock );
196197 return ;
197198 }
198199
199- m -> err = err ;
200+ WRITE_ONCE (m -> err , err );
201+ ASSERT_EXCLUSIVE_WRITER (m -> err );
200202
201203 list_for_each_entry_safe (req , rtmp , & m -> req_list , req_list ) {
202204 list_move (& req -> req_list , & cancel_list );
@@ -283,7 +285,7 @@ static void p9_read_work(struct work_struct *work)
283285
284286 m = container_of (work , struct p9_conn , rq );
285287
286- if (m -> err < 0 )
288+ if (READ_ONCE ( m -> err ) < 0 )
287289 return ;
288290
289291 p9_debug (P9_DEBUG_TRANS , "start mux %p pos %zd\n" , m , m -> rc .offset );
@@ -450,7 +452,7 @@ static void p9_write_work(struct work_struct *work)
450452
451453 m = container_of (work , struct p9_conn , wq );
452454
453- if (m -> err < 0 ) {
455+ if (READ_ONCE ( m -> err ) < 0 ) {
454456 clear_bit (Wworksched , & m -> wsched );
455457 return ;
456458 }
@@ -622,7 +624,7 @@ static void p9_poll_mux(struct p9_conn *m)
622624 __poll_t n ;
623625 int err = - ECONNRESET ;
624626
625- if (m -> err < 0 )
627+ if (READ_ONCE ( m -> err ) < 0 )
626628 return ;
627629
628630 n = p9_fd_poll (m -> client , NULL , & err );
@@ -665,6 +667,7 @@ static void p9_poll_mux(struct p9_conn *m)
665667static int p9_fd_request (struct p9_client * client , struct p9_req_t * req )
666668{
667669 __poll_t n ;
670+ int err ;
668671 struct p9_trans_fd * ts = client -> trans ;
669672 struct p9_conn * m = & ts -> conn ;
670673
@@ -673,9 +676,10 @@ static int p9_fd_request(struct p9_client *client, struct p9_req_t *req)
673676
674677 spin_lock (& m -> req_lock );
675678
676- if (m -> err < 0 ) {
679+ err = READ_ONCE (m -> err );
680+ if (err < 0 ) {
677681 spin_unlock (& m -> req_lock );
678- return m -> err ;
682+ return err ;
679683 }
680684
681685 WRITE_ONCE (req -> status , REQ_STATUS_UNSENT );
@@ -954,64 +958,55 @@ static void p9_fd_close(struct p9_client *client)
954958 kfree (ts );
955959}
956960
957- /*
958- * stolen from NFS - maybe should be made a generic function?
959- */
960- static inline int valid_ipaddr4 (const char * buf )
961- {
962- int rc , count , in [4 ];
963-
964- rc = sscanf (buf , "%d.%d.%d.%d" , & in [0 ], & in [1 ], & in [2 ], & in [3 ]);
965- if (rc != 4 )
966- return - EINVAL ;
967- for (count = 0 ; count < 4 ; count ++ ) {
968- if (in [count ] > 255 )
969- return - EINVAL ;
970- }
971- return 0 ;
972- }
973-
974961static int p9_bind_privport (struct socket * sock )
975962{
976- struct sockaddr_in cl ;
963+ struct sockaddr_storage stor = { 0 } ;
977964 int port , err = - EINVAL ;
978965
979- memset (& cl , 0 , sizeof (cl ));
980- cl .sin_family = AF_INET ;
981- cl .sin_addr .s_addr = htonl (INADDR_ANY );
966+ stor .ss_family = sock -> ops -> family ;
967+ if (stor .ss_family == AF_INET )
968+ ((struct sockaddr_in * )& stor )-> sin_addr .s_addr = htonl (INADDR_ANY );
969+ else
970+ ((struct sockaddr_in6 * )& stor )-> sin6_addr = in6addr_any ;
982971 for (port = p9_ipport_resv_max ; port >= p9_ipport_resv_min ; port -- ) {
983- cl .sin_port = htons ((ushort )port );
984- err = kernel_bind (sock , (struct sockaddr * )& cl , sizeof (cl ));
972+ if (stor .ss_family == AF_INET )
973+ ((struct sockaddr_in * )& stor )-> sin_port = htons ((ushort )port );
974+ else
975+ ((struct sockaddr_in6 * )& stor )-> sin6_port = htons ((ushort )port );
976+ err = kernel_bind (sock , (struct sockaddr * )& stor , sizeof (stor ));
985977 if (err != - EADDRINUSE )
986978 break ;
987979 }
988980 return err ;
989981}
990982
991-
992983static int
993984p9_fd_create_tcp (struct p9_client * client , const char * addr , char * args )
994985{
995986 int err ;
987+ char port_str [6 ];
996988 struct socket * csocket ;
997- struct sockaddr_in sin_server ;
989+ struct sockaddr_storage stor = { 0 } ;
998990 struct p9_fd_opts opts ;
999991
1000992 err = parse_opts (args , & opts );
1001993 if (err < 0 )
1002994 return err ;
1003995
1004- if (addr == NULL || valid_ipaddr4 ( addr ) < 0 )
996+ if (! addr )
1005997 return - EINVAL ;
1006998
999+ sprintf (port_str , "%u" , opts .port );
1000+ err = inet_pton_with_scope (current -> nsproxy -> net_ns , AF_UNSPEC , addr ,
1001+ port_str , & stor );
1002+ if (err < 0 )
1003+ return err ;
1004+
10071005 csocket = NULL ;
10081006
10091007 client -> trans_opts .tcp .port = opts .port ;
10101008 client -> trans_opts .tcp .privport = opts .privport ;
1011- sin_server .sin_family = AF_INET ;
1012- sin_server .sin_addr .s_addr = in_aton (addr );
1013- sin_server .sin_port = htons (opts .port );
1014- err = __sock_create (current -> nsproxy -> net_ns , PF_INET ,
1009+ err = __sock_create (current -> nsproxy -> net_ns , stor .ss_family ,
10151010 SOCK_STREAM , IPPROTO_TCP , & csocket , 1 );
10161011 if (err ) {
10171012 pr_err ("%s (%d): problem creating socket\n" ,
@@ -1030,8 +1025,8 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
10301025 }
10311026
10321027 err = READ_ONCE (csocket -> ops )-> connect (csocket ,
1033- (struct sockaddr * )& sin_server ,
1034- sizeof (struct sockaddr_in ), 0 );
1028+ (struct sockaddr * )& stor ,
1029+ sizeof (stor ), 0 );
10351030 if (err < 0 ) {
10361031 pr_err ("%s (%d): problem connecting socket to %s\n" ,
10371032 __func__ , task_pid_nr (current ), addr );
0 commit comments