11use num_traits:: Zero ;
22use thiserror:: Error ;
33
4- #[ derive( Debug , Error ) ]
4+ #[ derive( Debug , Error , PartialEq ) ]
55pub enum ParseVectorError {
66 #[ error( "The input string is empty." ) ]
77 EmptyString { } ,
@@ -89,12 +89,15 @@ where
8989
9090#[ derive( PartialEq ) ]
9191enum ParseState {
92- Number ,
92+ Start ,
93+ Index ,
94+ Value ,
9395 Comma ,
9496 Colon ,
95- Start ,
97+ End ,
9698}
9799
100+ // Index -> Colon -> Value -> Comma
98101#[ inline( always) ]
99102pub fn parse_pgvector_svector < T : Zero + Clone , F > (
100103 input : & [ u8 ] ,
@@ -157,24 +160,59 @@ where
157160 let mut indexes = Vec :: < u32 > :: new ( ) ;
158161 let mut values = Vec :: < T > :: new ( ) ;
159162 let mut index: u32 = u32:: MAX ;
160- let mut state = ParseState :: Start ;
161163
162- for position in left + 1 ..right {
163- let c = input[ position] ;
164- match c {
165- b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' => {
166- if token. is_empty ( ) {
167- token. push ( b'$' ) ;
168- }
169- if token. try_push ( c) . is_err ( ) {
170- return Err ( ParseVectorError :: TooLongNumber { position } ) ;
164+ let mut state = ParseState :: Start ;
165+ let mut position = left;
166+ loop {
167+ if position == right {
168+ let end_with_number = state == ParseState :: Value && !token. is_empty ( ) ;
169+ let end_with_comma = state == ParseState :: Index && token. is_empty ( ) ;
170+ if end_with_number || end_with_comma {
171+ state = ParseState :: End ;
172+ } else {
173+ return Err ( ParseVectorError :: BadCharacter { position } ) ;
174+ }
175+ }
176+ match state {
177+ ParseState :: Index => {
178+ let c = input[ position] ;
179+ match c {
180+ b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' => {
181+ if token. is_empty ( ) {
182+ token. push ( b'$' ) ;
183+ }
184+ if token. try_push ( c) . is_err ( ) {
185+ return Err ( ParseVectorError :: TooLongNumber { position } ) ;
186+ }
187+ position += 1 ;
188+ }
189+ b':' => {
190+ state = ParseState :: Colon ;
191+ }
192+ b' ' => position += 1 ,
193+ _ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
171194 }
172- state = ParseState :: Number ;
173195 }
174- b',' => {
175- if state != ParseState :: Number {
176- return Err ( ParseVectorError :: BadCharacter { position } ) ;
196+ ParseState :: Value => {
197+ let c = input[ position] ;
198+ match c {
199+ b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' => {
200+ if token. is_empty ( ) {
201+ token. push ( b'$' ) ;
202+ }
203+ if token. try_push ( c) . is_err ( ) {
204+ return Err ( ParseVectorError :: TooLongNumber { position } ) ;
205+ }
206+ position += 1 ;
207+ }
208+ b',' => {
209+ state = ParseState :: Comma ;
210+ }
211+ b' ' => position += 1 ,
212+ _ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
177213 }
214+ }
215+ e @ ( ParseState :: Comma | ParseState :: End ) => {
178216 if !token. is_empty ( ) {
179217 // Safety: all bytes in `token` are ascii characters
180218 let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
@@ -191,60 +229,44 @@ where
191229 }
192230 index = u32:: MAX ;
193231 token. clear ( ) ;
194- state = ParseState :: Comma ;
195- } else {
232+ } else if e != ParseState :: End {
196233 return Err ( ParseVectorError :: TooShortNumber { position } ) ;
197234 }
235+ if e == ParseState :: End {
236+ break ;
237+ } else {
238+ state = ParseState :: Index ;
239+ position += 1 ;
240+ }
198241 }
199- b':' => {
242+ ParseState :: Colon => {
200243 if !token. is_empty ( ) {
201244 // Safety: all bytes in `token` are ascii characters
202245 let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
203246 index = s
204247 . parse :: < u32 > ( )
205248 . map_err ( |_| ParseVectorError :: BadParsing { position } ) ?;
206249 token. clear ( ) ;
207- state = ParseState :: Colon ;
208250 } else {
209251 return Err ( ParseVectorError :: TooShortNumber { position } ) ;
210252 }
253+ state = ParseState :: Value ;
254+ position += 1 ;
255+ }
256+ ParseState :: Start => {
257+ state = ParseState :: Index ;
258+ position += 1 ;
211259 }
212- b' ' => ( ) ,
213- _ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
214- }
215- }
216- // A valid case is either
217- // - empty string: ""
218- // - end with number when a index is extracted:"1:2, 3:4"
219- if state != ParseState :: Start && ( state != ParseState :: Number || index == u32:: MAX ) {
220- return Err ( ParseVectorError :: BadCharacter { position : right } ) ;
221- }
222- if !token. is_empty ( ) {
223- let position = right;
224- // Safety: all bytes in `token` are ascii characters
225- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
226- let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
227- if index as usize >= dims {
228- return Err ( ParseVectorError :: OutOfBound {
229- dims,
230- index : index as usize ,
231- } ) ;
232- }
233- if !num. is_zero ( ) {
234- indexes. push ( index) ;
235- values. push ( num) ;
236260 }
237- token. clear ( ) ;
238261 }
239- // sort values and indexes ascend by indexes
240262 let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
241263 indices. sort_by_key ( |& i| & indexes[ i] ) ;
242- let sortedValues : Vec < T > = indices
264+ let sorted_values : Vec < T > = indices
243265 . iter ( )
244266 . map ( |i| values. get ( * i) . unwrap ( ) . clone ( ) )
245267 . collect ( ) ;
246268 indexes. sort ( ) ;
247- Ok ( ( indexes, sortedValues , dims) )
269+ Ok ( ( indexes, sorted_values , dims) )
248270}
249271
250272#[ cfg( test) ]
@@ -260,6 +282,10 @@ mod tests {
260282 let exprs: HashMap < & str , ( Vec < u32 > , Vec < F32 > , usize ) > = HashMap :: from ( [
261283 ( "{}/1" , ( vec ! [ ] , vec ! [ ] , 1 ) ) ,
262284 ( "{0:1}/1" , ( vec ! [ 0 ] , vec ! [ F32 ( 1.0 ) ] , 1 ) ) ,
285+ (
286+ "{0:1, 1:-2, }/2" ,
287+ ( vec ! [ 0 , 1 ] , vec ! [ F32 ( 1.0 ) , F32 ( -2.0 ) ] , 2 ) ,
288+ ) ,
263289 ( "{0:1, 1:1.5}/2" , ( vec ! [ 0 , 1 ] , vec ! [ F32 ( 1.0 ) , F32 ( 1.5 ) ] , 2 ) ) ,
264290 (
265291 "{0:+3, 2:-4.1}/3" ,
@@ -280,27 +306,47 @@ mod tests {
280306
281307 #[ test]
282308 fn test_svector_parse_reject ( ) {
283- let exprs: Vec < & str > = vec ! [
284- "{" ,
285- "}" ,
286- "{:" ,
287- ":}" ,
288- "{0:1, 1:1.5}/1" ,
289- "{0:0, 1:0, 2:0}/2" ,
290- "{0:1, 1:2, 2:3}" ,
291- "{0:1, 1:2, 2:3" ,
292- "{0:1, 1:2}/" ,
293- "{0}/5" ,
294- "{0:}/5" ,
295- "{:0}/5" ,
296- "{0:, 1:2}/5" ,
297- "{0:1, 1}/5" ,
298- "/2" ,
299- "{}/1/2" ,
300- ] ;
301- for e in exprs {
309+ let exprs: HashMap < & str , ParseVectorError > = HashMap :: from ( [
310+ ( "{" , ParseVectorError :: BadParentheses { character : '{' } ) ,
311+ ( "}" , ParseVectorError :: BadParentheses { character : '{' } ) ,
312+ ( "{:" , ParseVectorError :: BadCharacter { position : 1 } ) ,
313+ ( ":}" , ParseVectorError :: BadCharacter { position : 0 } ) ,
314+ (
315+ "{0:1, 1:1.5}/1" ,
316+ ParseVectorError :: OutOfBound { dims : 1 , index : 1 } ,
317+ ) ,
318+ (
319+ "{0:0, 1:0, 2:0}/2" ,
320+ ParseVectorError :: OutOfBound { dims : 2 , index : 2 } ,
321+ ) ,
322+ (
323+ "{0:1, 1:2, 2:3}" ,
324+ ParseVectorError :: BadCharacter { position : 15 } ,
325+ ) ,
326+ (
327+ "{0:1, 1:2, 2:3" ,
328+ ParseVectorError :: BadCharacter { position : 12 } ,
329+ ) ,
330+ ( "{0:1, 1:2}/" , ParseVectorError :: BadParsing { position : 10 } ) ,
331+ ( "{0}/5" , ParseVectorError :: BadCharacter { position : 2 } ) ,
332+ ( "{0:}/5" , ParseVectorError :: BadCharacter { position : 3 } ) ,
333+ ( "{:0}/5" , ParseVectorError :: TooShortNumber { position : 1 } ) ,
334+ (
335+ "{0:, 1:2}/5" ,
336+ ParseVectorError :: TooShortNumber { position : 3 } ,
337+ ) ,
338+ ( "{0:1, 1}/5" , ParseVectorError :: BadCharacter { position : 7 } ) ,
339+ ( "/2" , ParseVectorError :: BadCharacter { position : 0 } ) ,
340+ ( "{}/1/2" , ParseVectorError :: BadCharacter { position : 2 } ) ,
341+ (
342+ "{1,2,3,4}/5" ,
343+ ParseVectorError :: BadCharacter { position : 2 } ,
344+ ) ,
345+ ] ) ;
346+ for ( e, err) in exprs {
302347 let ret = parse_pgvector_svector ( e. as_bytes ( ) , |s| s. parse :: < F32 > ( ) . ok ( ) ) ;
303- assert ! ( ret. is_err( ) , "at expr {e}" )
348+ assert ! ( ret. is_err( ) , "at expr {e}" ) ;
349+ assert_eq ! ( ret. unwrap_err( ) , err, "at expr {e}" ) ;
304350 }
305351 }
306352}
0 commit comments