@@ -149,36 +149,7 @@ where
149149 data : * const u8 ,
150150 data_len : usize ,
151151 ) -> Result < Self , Error > {
152- if dims == 0 {
153- return Err ( fmt_error ! (
154- ArrayError ,
155- "Zero-dimensional arrays are not supported" ,
156- ) ) ;
157- }
158- if data_len > MAX_ARRAY_BUFFER_SIZE {
159- return Err ( fmt_error ! (
160- ArrayError ,
161- "Array buffer size too big: {}, maximum: {}" ,
162- data_len,
163- MAX_ARRAY_BUFFER_SIZE
164- ) ) ;
165- }
166- let shape = slice:: from_raw_parts ( shape, dims) ;
167- let size = shape
168- . iter ( )
169- . try_fold ( std:: mem:: size_of :: < T > ( ) , |acc, & dim| {
170- acc. checked_mul ( dim)
171- . ok_or_else ( || fmt_error ! ( ArrayError , "Array buffer size too big" ) )
172- } ) ?;
173-
174- if size != data_len {
175- return Err ( fmt_error ! (
176- ArrayError ,
177- "Array buffer length mismatch (actual: {}, expected: {})" ,
178- data_len,
179- size
180- ) ) ;
181- }
152+ let shape = check_array_shape :: < T > ( dims, shape, data_len) ?;
182153 let strides = slice:: from_raw_parts ( strides, dims) ;
183154 let mut slice = None ;
184155 if data_len != 0 {
@@ -359,36 +330,7 @@ where
359330 data : * const u8 ,
360331 data_len : usize ,
361332 ) -> Result < Self , Error > {
362- if dims == 0 {
363- return Err ( fmt_error ! (
364- ArrayError ,
365- "Zero-dimensional arrays are not supported" ,
366- ) ) ;
367- }
368- if data_len > MAX_ARRAY_BUFFER_SIZE {
369- return Err ( fmt_error ! (
370- ArrayError ,
371- "Array buffer size too big: {}, maximum: {}" ,
372- data_len,
373- MAX_ARRAY_BUFFER_SIZE
374- ) ) ;
375- }
376- let shape = slice:: from_raw_parts ( shape, dims) ;
377- let size = shape
378- . iter ( )
379- . try_fold ( std:: mem:: size_of :: < T > ( ) , |acc, & dim| {
380- acc. checked_mul ( dim)
381- . ok_or_else ( || fmt_error ! ( ArrayError , "Array buffer size too big" ) )
382- } ) ?;
383-
384- if size != data_len {
385- return Err ( fmt_error ! (
386- ArrayError ,
387- "Array buffer length mismatch (actual: {}, expected: {})" ,
388- data_len,
389- size
390- ) ) ;
391- }
333+ let shape = check_array_shape :: < T > ( dims, shape, data_len) ?;
392334 let mut slice = None ;
393335 if data_len != 0 {
394336 slice = Some ( slice:: from_raw_parts ( data, data_len) ) ;
@@ -402,6 +344,45 @@ where
402344 }
403345}
404346
347+ fn check_array_shape < T > (
348+ dims : usize ,
349+ shape : * const usize ,
350+ data_len : usize ,
351+ ) -> Result < & ' static [ usize ] , Error > {
352+ if dims == 0 {
353+ return Err ( fmt_error ! (
354+ ArrayError ,
355+ "Zero-dimensional arrays are not supported" ,
356+ ) ) ;
357+ }
358+ if data_len > MAX_ARRAY_BUFFER_SIZE {
359+ return Err ( fmt_error ! (
360+ ArrayError ,
361+ "Array buffer size too big: {}, maximum: {}" ,
362+ data_len,
363+ MAX_ARRAY_BUFFER_SIZE
364+ ) ) ;
365+ }
366+ let shape = unsafe { slice:: from_raw_parts ( shape, dims) } ;
367+
368+ let size = shape
369+ . iter ( )
370+ . try_fold ( std:: mem:: size_of :: < T > ( ) , |acc, & dim| {
371+ acc. checked_mul ( dim)
372+ . ok_or_else ( || fmt_error ! ( ArrayError , "Array buffer size too big" ) )
373+ } ) ?;
374+
375+ if size != data_len {
376+ return Err ( fmt_error ! (
377+ ArrayError ,
378+ "Array buffer length mismatch (actual: {}, expected: {})" ,
379+ data_len,
380+ size
381+ ) ) ;
382+ }
383+ Ok ( shape)
384+ }
385+
405386#[ cfg( test) ]
406387mod tests {
407388 use super :: * ;
@@ -909,4 +890,78 @@ mod tests {
909890 assert_eq ! ( buf, expected) ;
910891 Ok ( ( ) )
911892 }
893+
894+ #[ test]
895+ fn test_c_major_array_basic ( ) -> TestResult {
896+ let test_data = [ 1.1 , 2.2 , 3.3 , 4.4 ] ;
897+ let array_view: CMajorArrayView < ' _ , f64 > = unsafe {
898+ CMajorArrayView :: new (
899+ 2 ,
900+ [ 2 , 2 ] . as_ptr ( ) ,
901+ test_data. as_ptr ( ) as * const u8 ,
902+ test_data. len ( ) * 8usize ,
903+ )
904+ } ?;
905+ let mut buffer = Buffer :: new ( ProtocolVersion :: V2 ) ;
906+ buffer. table ( "my_test" ) ?;
907+ buffer. column_arr ( "temperature" , & array_view) ?;
908+ let data = buffer. as_bytes ( ) ;
909+ assert_eq ! ( & data[ 0 ..7 ] , b"my_test" ) ;
910+ assert_eq ! ( & data[ 8 ..19 ] , b"temperature" ) ;
911+ assert_eq ! (
912+ & data[ 19 ..24 ] ,
913+ & [
914+ b'=' , b'=' , 14u8 , // ARRAY_BINARY_FORMAT_TYPE
915+ 10u8 , // ArrayColumnTypeTag::Double.into()
916+ 2u8
917+ ]
918+ ) ;
919+ assert_eq ! (
920+ & data[ 24 ..32 ] ,
921+ [ 2i32 . to_le_bytes( ) , 2i32 . to_le_bytes( ) ] . concat( )
922+ ) ;
923+ assert_eq ! (
924+ & data[ 32 ..64 ] ,
925+ & [
926+ 1.1f64 . to_ne_bytes( ) ,
927+ 2.2f64 . to_le_bytes( ) ,
928+ 3.3f64 . to_le_bytes( ) ,
929+ 4.4f64 . to_le_bytes( ) ,
930+ ]
931+ . concat( )
932+ ) ;
933+ Ok ( ( ) )
934+ }
935+
936+ #[ test]
937+ fn test_c_major_empty_array ( ) -> TestResult {
938+ let test_data = [ ] ;
939+ let array_view: CMajorArrayView < ' _ , f64 > = unsafe {
940+ CMajorArrayView :: new (
941+ 2 ,
942+ [ 2 , 0 ] . as_ptr ( ) ,
943+ test_data. as_ptr ( ) ,
944+ test_data. len ( ) * 8usize ,
945+ )
946+ } ?;
947+ let mut buffer = Buffer :: new ( ProtocolVersion :: V2 ) ;
948+ buffer. table ( "my_test" ) ?;
949+ buffer. column_arr ( "temperature" , & array_view) ?;
950+ let data = buffer. as_bytes ( ) ;
951+ assert_eq ! ( & data[ 0 ..7 ] , b"my_test" ) ;
952+ assert_eq ! ( & data[ 8 ..19 ] , b"temperature" ) ;
953+ assert_eq ! (
954+ & data[ 19 ..24 ] ,
955+ & [
956+ b'=' , b'=' , 14u8 , // ARRAY_BINARY_FORMAT_TYPE
957+ 10u8 , // ArrayColumnTypeTag::Double.into()
958+ 2u8
959+ ]
960+ ) ;
961+ assert_eq ! (
962+ & data[ 24 ..32 ] ,
963+ [ 2i32 . to_le_bytes( ) , 0i32 . to_le_bytes( ) ] . concat( )
964+ ) ;
965+ Ok ( ( ) )
966+ }
912967}
0 commit comments