@@ -554,7 +554,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
554554 case SMB_DIRECT_MSG_DATA_TRANSFER : {
555555 struct smb_direct_data_transfer * data_transfer =
556556 (struct smb_direct_data_transfer * )recvmsg -> packet ;
557- unsigned int data_length ;
557+ u32 remaining_data_length , data_offset , data_length ;
558558 int avail_recvmsg_count , receive_credits ;
559559
560560 if (wc -> byte_len <
@@ -564,15 +564,25 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
564564 return ;
565565 }
566566
567+ remaining_data_length = le32_to_cpu (data_transfer -> remaining_data_length );
567568 data_length = le32_to_cpu (data_transfer -> data_length );
568- if (data_length ) {
569- if (wc -> byte_len < sizeof (struct smb_direct_data_transfer ) +
570- (u64 )data_length ) {
571- put_recvmsg (t , recvmsg );
572- smb_direct_disconnect_rdma_connection (t );
573- return ;
574- }
569+ data_offset = le32_to_cpu (data_transfer -> data_offset );
570+ if (wc -> byte_len < data_offset ||
571+ wc -> byte_len < (u64 )data_offset + data_length ) {
572+ put_recvmsg (t , recvmsg );
573+ smb_direct_disconnect_rdma_connection (t );
574+ return ;
575+ }
576+ if (remaining_data_length > t -> max_fragmented_recv_size ||
577+ data_length > t -> max_fragmented_recv_size ||
578+ (u64 )remaining_data_length + (u64 )data_length >
579+ (u64 )t -> max_fragmented_recv_size ) {
580+ put_recvmsg (t , recvmsg );
581+ smb_direct_disconnect_rdma_connection (t );
582+ return ;
583+ }
575584
585+ if (data_length ) {
576586 if (t -> full_packet_received )
577587 recvmsg -> first_segment = true;
578588
@@ -1209,78 +1219,130 @@ static int smb_direct_writev(struct ksmbd_transport *t,
12091219 bool need_invalidate , unsigned int remote_key )
12101220{
12111221 struct smb_direct_transport * st = smb_trans_direct_transfort (t );
1212- int remaining_data_length ;
1213- int start , i , j ;
1214- int max_iov_size = st -> max_send_size -
1222+ size_t remaining_data_length ;
1223+ size_t iov_idx ;
1224+ size_t iov_ofs ;
1225+ size_t max_iov_size = st -> max_send_size -
12151226 sizeof (struct smb_direct_data_transfer );
12161227 int ret ;
1217- struct kvec vec ;
12181228 struct smb_direct_send_ctx send_ctx ;
1229+ int error = 0 ;
12191230
12201231 if (st -> status != SMB_DIRECT_CS_CONNECTED )
12211232 return - ENOTCONN ;
12221233
12231234 //FIXME: skip RFC1002 header..
1235+ if (WARN_ON_ONCE (niovs <= 1 || iov [0 ].iov_len != 4 ))
1236+ return - EINVAL ;
12241237 buflen -= 4 ;
1238+ iov_idx = 1 ;
1239+ iov_ofs = 0 ;
12251240
12261241 remaining_data_length = buflen ;
12271242 ksmbd_debug (RDMA , "Sending smb (RDMA): smb_len=%u\n" , buflen );
12281243
12291244 smb_direct_send_ctx_init (st , & send_ctx , need_invalidate , remote_key );
1230- start = i = 1 ;
1231- buflen = 0 ;
1232- while (true) {
1233- buflen += iov [i ].iov_len ;
1234- if (buflen > max_iov_size ) {
1235- if (i > start ) {
1236- remaining_data_length -=
1237- (buflen - iov [i ].iov_len );
1238- ret = smb_direct_post_send_data (st , & send_ctx ,
1239- & iov [start ], i - start ,
1240- remaining_data_length );
1241- if (ret )
1245+ while (remaining_data_length ) {
1246+ struct kvec vecs [SMB_DIRECT_MAX_SEND_SGES - 1 ]; /* minus smbdirect hdr */
1247+ size_t possible_bytes = max_iov_size ;
1248+ size_t possible_vecs ;
1249+ size_t bytes = 0 ;
1250+ size_t nvecs = 0 ;
1251+
1252+ /*
1253+ * For the last message remaining_data_length should be
1254+ * have been 0 already!
1255+ */
1256+ if (WARN_ON_ONCE (iov_idx >= niovs )) {
1257+ error = - EINVAL ;
1258+ goto done ;
1259+ }
1260+
1261+ /*
1262+ * We have 2 factors which limit the arguments we pass
1263+ * to smb_direct_post_send_data():
1264+ *
1265+ * 1. The number of supported sges for the send,
1266+ * while one is reserved for the smbdirect header.
1267+ * And we currently need one SGE per page.
1268+ * 2. The number of negotiated payload bytes per send.
1269+ */
1270+ possible_vecs = min_t (size_t , ARRAY_SIZE (vecs ), niovs - iov_idx );
1271+
1272+ while (iov_idx < niovs && possible_vecs && possible_bytes ) {
1273+ struct kvec * v = & vecs [nvecs ];
1274+ int page_count ;
1275+
1276+ v -> iov_base = ((u8 * )iov [iov_idx ].iov_base ) + iov_ofs ;
1277+ v -> iov_len = min_t (size_t ,
1278+ iov [iov_idx ].iov_len - iov_ofs ,
1279+ possible_bytes );
1280+ page_count = get_buf_page_count (v -> iov_base , v -> iov_len );
1281+ if (page_count > possible_vecs ) {
1282+ /*
1283+ * If the number of pages in the buffer
1284+ * is to much (because we currently require
1285+ * one SGE per page), we need to limit the
1286+ * length.
1287+ *
1288+ * We know possible_vecs is at least 1,
1289+ * so we always keep the first page.
1290+ *
1291+ * We need to calculate the number extra
1292+ * pages (epages) we can also keep.
1293+ *
1294+ * We calculate the number of bytes in the
1295+ * first page (fplen), this should never be
1296+ * larger than v->iov_len because page_count is
1297+ * at least 2, but adding a limitation feels
1298+ * better.
1299+ *
1300+ * Then we calculate the number of bytes (elen)
1301+ * we can keep for the extra pages.
1302+ */
1303+ size_t epages = possible_vecs - 1 ;
1304+ size_t fpofs = offset_in_page (v -> iov_base );
1305+ size_t fplen = min_t (size_t , PAGE_SIZE - fpofs , v -> iov_len );
1306+ size_t elen = min_t (size_t , v -> iov_len - fplen , epages * PAGE_SIZE );
1307+
1308+ v -> iov_len = fplen + elen ;
1309+ page_count = get_buf_page_count (v -> iov_base , v -> iov_len );
1310+ if (WARN_ON_ONCE (page_count > possible_vecs )) {
1311+ /*
1312+ * Something went wrong in the above
1313+ * logic...
1314+ */
1315+ error = - EINVAL ;
12421316 goto done ;
1243- } else {
1244- /* iov[start] is too big, break it */
1245- int nvec = (buflen + max_iov_size - 1 ) /
1246- max_iov_size ;
1247-
1248- for (j = 0 ; j < nvec ; j ++ ) {
1249- vec .iov_base =
1250- (char * )iov [start ].iov_base +
1251- j * max_iov_size ;
1252- vec .iov_len =
1253- min_t (int , max_iov_size ,
1254- buflen - max_iov_size * j );
1255- remaining_data_length -= vec .iov_len ;
1256- ret = smb_direct_post_send_data (st , & send_ctx , & vec , 1 ,
1257- remaining_data_length );
1258- if (ret )
1259- goto done ;
12601317 }
1261- i ++ ;
1262- if (i == niovs )
1263- break ;
12641318 }
1265- start = i ;
1266- buflen = 0 ;
1267- } else {
1268- i ++ ;
1269- if (i == niovs ) {
1270- /* send out all remaining vecs */
1271- remaining_data_length -= buflen ;
1272- ret = smb_direct_post_send_data (st , & send_ctx ,
1273- & iov [start ], i - start ,
1274- remaining_data_length );
1275- if (ret )
1276- goto done ;
1277- break ;
1319+ possible_vecs -= page_count ;
1320+ nvecs += 1 ;
1321+ possible_bytes -= v -> iov_len ;
1322+ bytes += v -> iov_len ;
1323+
1324+ iov_ofs += v -> iov_len ;
1325+ if (iov_ofs >= iov [iov_idx ].iov_len ) {
1326+ iov_idx += 1 ;
1327+ iov_ofs = 0 ;
12781328 }
12791329 }
1330+
1331+ remaining_data_length -= bytes ;
1332+
1333+ ret = smb_direct_post_send_data (st , & send_ctx ,
1334+ vecs , nvecs ,
1335+ remaining_data_length );
1336+ if (unlikely (ret )) {
1337+ error = ret ;
1338+ goto done ;
1339+ }
12801340 }
12811341
12821342done :
12831343 ret = smb_direct_flush_send_list (st , & send_ctx , true);
1344+ if (unlikely (!ret && error ))
1345+ ret = error ;
12841346
12851347 /*
12861348 * As an optimization, we don't wait for individual I/O to finish
@@ -1744,6 +1806,11 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
17441806 return - EINVAL ;
17451807 }
17461808
1809+ if (device -> attrs .max_send_sge < SMB_DIRECT_MAX_SEND_SGES ) {
1810+ pr_err ("warning: device max_send_sge = %d too small\n" ,
1811+ device -> attrs .max_send_sge );
1812+ return - EINVAL ;
1813+ }
17471814 if (device -> attrs .max_recv_sge < SMB_DIRECT_MAX_RECV_SGES ) {
17481815 pr_err ("warning: device max_recv_sge = %d too small\n" ,
17491816 device -> attrs .max_recv_sge );
@@ -1767,7 +1834,7 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
17671834
17681835 cap -> max_send_wr = max_send_wrs ;
17691836 cap -> max_recv_wr = t -> recv_credit_max ;
1770- cap -> max_send_sge = max_sge_per_wr ;
1837+ cap -> max_send_sge = SMB_DIRECT_MAX_SEND_SGES ;
17711838 cap -> max_recv_sge = SMB_DIRECT_MAX_RECV_SGES ;
17721839 cap -> max_inline_data = 0 ;
17731840 cap -> max_rdma_ctxs = t -> max_rw_credits ;
0 commit comments