@@ -15,6 +15,8 @@ pub enum ParseVectorError {
1515 TooShortNumber { position : usize } ,
1616 #[ error( "Bad parsing at position {position}" ) ]
1717 BadParsing { position : usize } ,
18+ #[ error( "Index out of bounds: the dim is {dims} but the index is {index}" ) ]
19+ OutOfBound { dims : usize , index : usize } ,
1820}
1921
2022#[ inline( always) ]
8587 Ok ( vector)
8688}
8789
90+ #[ derive( PartialEq ) ]
91+ enum ParseState {
92+ Number ,
93+ Comma ,
94+ Colon ,
95+ Start ,
96+ }
97+
8898#[ inline( always) ]
8999pub fn parse_pgvector_svector < T : Zero + Clone , F > (
90100 input : & [ u8 ] ,
@@ -136,7 +146,9 @@ where
136146 } ;
137147 let mut indexes = Vec :: < u32 > :: new ( ) ;
138148 let mut values = Vec :: < T > :: new ( ) ;
139- let mut index: u32 = 0 ;
149+ let mut index: u32 = u32:: MAX ;
150+ let mut state = ParseState :: Start ;
151+
140152 for position in left + 1 ..right {
141153 let c = input[ position] ;
142154 match c {
@@ -147,15 +159,29 @@ where
147159 if token. try_push ( c) . is_err ( ) {
148160 return Err ( ParseVectorError :: TooLongNumber { position } ) ;
149161 }
162+ state = ParseState :: Number ;
150163 }
151164 b',' => {
165+ if state != ParseState :: Number {
166+ return Err ( ParseVectorError :: BadCharacter { position } ) ;
167+ }
152168 if !token. is_empty ( ) {
153169 // Safety: all bytes in `token` are ascii characters
154170 let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
155171 let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
156- indexes. push ( index) ;
157- values. push ( num) ;
172+ if index as usize >= dims {
173+ return Err ( ParseVectorError :: OutOfBound {
174+ dims,
175+ index : index as usize ,
176+ } ) ;
177+ }
178+ if !num. is_zero ( ) {
179+ indexes. push ( index) ;
180+ values. push ( num) ;
181+ }
182+ index = u32:: MAX ;
158183 token. clear ( ) ;
184+ state = ParseState :: Comma ;
159185 } else {
160186 return Err ( ParseVectorError :: TooShortNumber { position } ) ;
161187 }
@@ -168,6 +194,7 @@ where
168194 . parse :: < u32 > ( )
169195 . map_err ( |_| ParseVectorError :: BadParsing { position } ) ?;
170196 token. clear ( ) ;
197+ state = ParseState :: Colon ;
171198 } else {
172199 return Err ( ParseVectorError :: TooShortNumber { position } ) ;
173200 }
@@ -176,14 +203,90 @@ where
176203 _ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
177204 }
178205 }
206+ if state != ParseState :: Start && ( state != ParseState :: Number || index == u32:: MAX ) {
207+ return Err ( ParseVectorError :: BadCharacter { position : right } ) ;
208+ }
179209 if !token. is_empty ( ) {
180210 let position = right;
181211 // Safety: all bytes in `token` are ascii characters
182212 let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
183213 let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
184- indexes. push ( index) ;
185- values. push ( num) ;
214+ if index as usize >= dims {
215+ return Err ( ParseVectorError :: OutOfBound {
216+ dims,
217+ index : index as usize ,
218+ } ) ;
219+ }
220+ if !num. is_zero ( ) {
221+ indexes. push ( index) ;
222+ values. push ( num) ;
223+ }
186224 token. clear ( ) ;
187225 }
188- Ok ( ( indexes, values, dims) )
226+ // sort values and indexes ascend by indexes
227+ let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
228+ indices. sort_by_key ( |& i| & indexes[ i] ) ;
229+ let sortedValues: Vec < T > = indices
230+ . iter ( )
231+ . map ( |i| values. get ( * i) . unwrap ( ) . clone ( ) )
232+ . collect ( ) ;
233+ indexes. sort ( ) ;
234+ Ok ( ( indexes, sortedValues, dims) )
235+ }
236+
237+ #[ cfg( test) ]
238+ mod tests {
239+ use std:: collections:: HashMap ;
240+
241+ use base:: scalar:: F32 ;
242+
243+ use super :: * ;
244+
245+ #[ test]
246+ fn test_svector_parse_accept ( ) {
247+ let exprs: HashMap < & str , ( Vec < u32 > , Vec < F32 > , usize ) > = HashMap :: from ( [
248+ ( "{}/1" , ( vec ! [ ] , vec ! [ ] , 1 ) ) ,
249+ ( "{0:1}/1" , ( vec ! [ 0 ] , vec ! [ F32 ( 1.0 ) ] , 1 ) ) ,
250+ ( "{0:1, 1:1.5}/2" , ( vec ! [ 0 , 1 ] , vec ! [ F32 ( 1.0 ) , F32 ( 1.5 ) ] , 2 ) ) ,
251+ (
252+ "{0:+3, 2:-4.1}/3" ,
253+ ( vec ! [ 0 , 2 ] , vec ! [ F32 ( 3.0 ) , F32 ( -4.1 ) ] , 3 ) ,
254+ ) ,
255+ ( "{0:0, 1:0, 2:0}/3" , ( vec ! [ ] , vec ! [ ] , 3 ) ) ,
256+ (
257+ "{3:3, 2:2, 1:1, 0:0}/4" ,
258+ ( vec ! [ 1 , 2 , 3 ] , vec ! [ F32 ( 1.0 ) , F32 ( 2.0 ) , F32 ( 3.0 ) ] , 4 ) ,
259+ ) ,
260+ ] ) ;
261+ for ( e, ans) in exprs {
262+ let ret = parse_pgvector_svector ( e. as_bytes ( ) , |s| s. parse :: < F32 > ( ) . ok ( ) ) ;
263+ assert ! ( ret. is_ok( ) , "at expr {e}" ) ;
264+ assert_eq ! ( ret. unwrap( ) , ans, "at expr {e}" ) ;
265+ }
266+ }
267+
268+ #[ test]
269+ fn test_svector_parse_reject ( ) {
270+ let exprs: Vec < & str > = vec ! [
271+ "{" ,
272+ "}" ,
273+ "{:" ,
274+ ":}" ,
275+ "{0:1, 1:1.5}/1" ,
276+ "{0:0, 1:0, 2:0}/2" ,
277+ "{0:1, 1:2, 2:3}" ,
278+ "{0:1, 1:2, 2:3" ,
279+ "{0:1, 1:2}/" ,
280+ "{0}/5" ,
281+ "{0:}/5" ,
282+ "{:0}/5" ,
283+ "{0:, 1:2}/5" ,
284+ "{0:1, 1}/5" ,
285+ "/2" ,
286+ ] ;
287+ for e in exprs {
288+ let ret = parse_pgvector_svector ( e. as_bytes ( ) , |s| s. parse :: < F32 > ( ) . ok ( ) ) ;
289+ assert ! ( ret. is_err( ) , "at expr {e}" )
290+ }
291+ }
189292}
0 commit comments