@@ -85,35 +85,54 @@ where
8585 Ok ( vector)
8686}
8787
88- #[ derive( PartialEq , Debug ) ]
88+ #[ derive( PartialEq , Debug , Clone ) ]
8989enum ParseState {
9090 Start ,
9191 LeftBracket ,
9292 Index ,
93+ Colon ,
9394 Value ,
94- Splitter ,
9595 Comma ,
96- Length ,
96+ RightBracket ,
97+ Splitter ,
98+ Dims ,
9799}
98100
99101#[ inline( always) ]
100- pub fn svector_filter_nonzero < T : Zero + Clone + PartialEq > (
102+ pub fn svector_sorted < T : Zero + Clone + PartialEq > (
101103 indexes : & [ u32 ] ,
102104 values : & [ T ] ,
103105) -> ( Vec < u32 > , Vec < T > ) {
104- let non_zero_indexes: Vec < u32 > = indexes
105- . iter ( )
106- . enumerate ( )
107- . filter ( |( i, _) | values. get ( * i) . unwrap ( ) != & T :: zero ( ) )
108- . map ( |( _, x) | * x)
109- . collect ( ) ;
110- let non_zero_values: Vec < T > = indexes
111- . iter ( )
112- . enumerate ( )
113- . filter ( |( i, _) | values. get ( * i) . unwrap ( ) != & T :: zero ( ) )
114- . map ( |( i, _) | values. get ( i) . unwrap ( ) . clone ( ) )
115- . collect ( ) ;
116- ( non_zero_indexes, non_zero_values)
106+ let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
107+ indices. sort_by_key ( |& i| & indexes[ i] ) ;
108+
109+ let mut sorted_indexes: Vec < u32 > = Vec :: with_capacity ( indexes. len ( ) ) ;
110+ let mut sorted_values: Vec < T > = Vec :: with_capacity ( indexes. len ( ) ) ;
111+ for i in indices {
112+ sorted_indexes. push ( * indexes. get ( i) . unwrap ( ) ) ;
113+ sorted_values. push ( values. get ( i) . unwrap ( ) . clone ( ) ) ;
114+ }
115+ ( sorted_indexes, sorted_values)
116+ }
117+
118+ #[ inline( always) ]
119+ pub fn svector_filter_nonzero < T : Zero + Clone + PartialEq > (
120+ indexes : & mut Vec < u32 > ,
121+ values : & mut Vec < T > ,
122+ ) {
123+ // Index must be sorted!
124+ let mut i = 0 ;
125+ let mut j = 0 ;
126+ while j < values. len ( ) {
127+ if !values[ j] . is_zero ( ) {
128+ indexes[ i] = indexes[ j] ;
129+ values[ i] = values[ j] . clone ( ) ;
130+ i += 1 ;
131+ }
132+ j += 1 ;
133+ }
134+ indexes. truncate ( i) ;
135+ values. truncate ( i) ;
117136}
118137
119138#[ inline( always) ]
@@ -133,110 +152,82 @@ where
133152 let mut values = Vec :: < T > :: new ( ) ;
134153
135154 let mut state = ParseState :: Start ;
136- for ( position, char) in input. iter ( ) . enumerate ( ) {
137- let c = * char;
138- match ( & state, c) {
139- ( _, b' ' ) => { }
140- ( ParseState :: Start , b'{' ) => {
141- state = ParseState :: LeftBracket ;
142- }
155+ for ( position, c) in input. iter ( ) . copied ( ) . enumerate ( ) {
156+ state = match ( & state, c) {
157+ ( _, b' ' ) => state,
158+ ( ParseState :: Start , b'{' ) => ParseState :: LeftBracket ,
143159 (
144160 ParseState :: LeftBracket | ParseState :: Index | ParseState :: Comma ,
145161 b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' ,
146162 ) => {
147- if token. is_empty ( ) {
148- token. push ( b'$' ) ;
149- }
150163 if token. try_push ( c) . is_err ( ) {
151164 return Err ( ParseVectorError :: TooLongNumber { position } ) ;
152165 }
153- state = ParseState :: Index ;
166+ ParseState :: Index
154167 }
155- ( ParseState :: LeftBracket | ParseState :: Comma , b'}' ) => {
156- state = ParseState :: Splitter ;
168+ ( ParseState :: Colon , b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' ) => {
169+ if token. try_push ( c) . is_err ( ) {
170+ return Err ( ParseVectorError :: TooLongNumber { position } ) ;
171+ }
172+ ParseState :: Value
157173 }
174+ ( ParseState :: LeftBracket | ParseState :: Comma , b'}' ) => ParseState :: RightBracket ,
158175 ( ParseState :: Index , b':' ) => {
159- if token. is_empty ( ) {
160- return Err ( ParseVectorError :: TooShortNumber { position } ) ;
161- }
162- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
176+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
163177 let index = s
164178 . parse :: < u32 > ( )
165179 . map_err ( |_| ParseVectorError :: BadParsing { position } ) ?;
166180 indexes. push ( index) ;
167181 token. clear ( ) ;
168- state = ParseState :: Value ;
182+ ParseState :: Colon
169183 }
170184 ( ParseState :: Value , b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' ) => {
171- if token. is_empty ( ) {
172- token. push ( b'$' ) ;
173- }
174185 if token. try_push ( c) . is_err ( ) {
175186 return Err ( ParseVectorError :: TooLongNumber { position } ) ;
176187 }
188+ ParseState :: Value
177189 }
178190 ( ParseState :: Value , b',' ) => {
179- if token. is_empty ( ) {
180- return Err ( ParseVectorError :: TooShortNumber { position } ) ;
181- }
182- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
191+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
183192 let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
184193 values. push ( num) ;
185194 token. clear ( ) ;
186- state = ParseState :: Comma ;
195+ ParseState :: Comma
187196 }
188197 ( ParseState :: Value , b'}' ) => {
189198 if token. is_empty ( ) {
190199 return Err ( ParseVectorError :: TooShortNumber { position } ) ;
191200 }
192- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
201+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
193202 let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
194203 values. push ( num) ;
195204 token. clear ( ) ;
196- state = ParseState :: Splitter ;
197- }
198- ( ParseState :: Splitter , b'/' ) => {
199- state = ParseState :: Length ;
205+ ParseState :: RightBracket
200206 }
201- ( ParseState :: Length , b'0' ..=b'9' ) => {
202- if token. is_empty ( ) {
203- token. push ( b'$' ) ;
204- }
207+ ( ParseState :: RightBracket , b'/' ) => ParseState :: Splitter ,
208+ ( ParseState :: Dims | ParseState :: Splitter , b'0' ..=b'9' ) => {
205209 if token. try_push ( c) . is_err ( ) {
206210 return Err ( ParseVectorError :: TooLongNumber { position } ) ;
207211 }
212+ ParseState :: Dims
208213 }
209214 ( _, _) => {
210215 return Err ( ParseVectorError :: BadCharacter { position } ) ;
211216 }
212217 }
213218 }
214- if state != ParseState :: Length {
219+ if state != ParseState :: Dims {
215220 return Err ( ParseVectorError :: BadParsing {
216221 position : input. len ( ) ,
217222 } ) ;
218223 }
219- if token. is_empty ( ) {
220- return Err ( ParseVectorError :: TooShortNumber {
221- position : input. len ( ) ,
222- } ) ;
223- }
224- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
224+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
225225 let dims = s
226226 . parse :: < usize > ( )
227227 . map_err ( |_| ParseVectorError :: BadParsing {
228228 position : input. len ( ) ,
229229 } ) ?;
230-
231- let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
232- indices. sort_by_key ( |& i| & indexes[ i] ) ;
233- let sorted_values: Vec < T > = indices
234- . iter ( )
235- . map ( |i| values. get ( * i) . unwrap ( ) . clone ( ) )
236- . collect ( ) ;
237- indexes. sort ( ) ;
238-
239- Ok ( ( indexes, sorted_values, dims) )
230+ Ok ( ( indexes, values, dims) )
240231}
241232
242233#[ cfg( test) ]
@@ -266,8 +257,8 @@ mod tests {
266257 (
267258 "{3:3, 2:2, 1:1, 0:0}/4" ,
268259 (
269- vec![ 0 , 1 , 2 , 3 ] ,
270- vec![ F32 ( 0 .0) , F32 ( 1 .0) , F32 ( 2 .0) , F32 ( 3 .0) ] ,
260+ vec![ 3 , 2 , 1 , 0 ] ,
261+ vec![ F32 ( 3 .0) , F32 ( 2 .0) , F32 ( 1 .0) , F32 ( 0 .0) ] ,
271262 4 ,
272263 ) ,
273264 ) ,
@@ -294,16 +285,13 @@ mod tests {
294285 "{0:1, 1:2, 2:3" ,
295286 ParseVectorError :: BadParsing { position: 14 } ,
296287 ) ,
297- (
298- "{0:1, 1:2}/" ,
299- ParseVectorError :: TooShortNumber { position: 11 } ,
300- ) ,
288+ ( "{0:1, 1:2}/" , ParseVectorError :: BadParsing { position: 11 } ) ,
301289 ( "{0}/5" , ParseVectorError :: BadCharacter { position: 2 } ) ,
302- ( "{0:}/5" , ParseVectorError :: TooShortNumber { position: 3 } ) ,
290+ ( "{0:}/5" , ParseVectorError :: BadCharacter { position: 3 } ) ,
303291 ( "{:0}/5" , ParseVectorError :: BadCharacter { position: 1 } ) ,
304292 (
305293 "{0:, 1:2}/5" ,
306- ParseVectorError :: TooShortNumber { position: 3 } ,
294+ ParseVectorError :: BadCharacter { position: 3 } ,
307295 ) ,
308296 ( "{0:1, 1}/5" , ParseVectorError :: BadCharacter { position: 7 } ) ,
309297 ( "/2" , ParseVectorError :: BadCharacter { position: 0 } ) ,
@@ -347,23 +335,33 @@ mod tests {
347335 ) ,
348336 (
349337 "{2:0, 1:0}/2" ,
350- ( vec![ 1 , 2 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
338+ ( vec![ 2 , 1 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
351339 ( vec![ ] , vec![ ] ) ,
352340 ) ,
353341 (
354342 "{2:0, 1:0, }/2" ,
355- ( vec![ 1 , 2 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
343+ ( vec![ 2 , 1 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
356344 ( vec![ ] , vec![ ] ) ,
357345 ) ,
346+ (
347+ "{3:2, 2:1, 1:0, 0:-1}/4" ,
348+ (
349+ vec![ 3 , 2 , 1 , 0 ] ,
350+ vec![ F32 ( 2.0 ) , F32 ( 1.0 ) , F32 ( 0.0 ) , F32 ( -1.0 ) ] ,
351+ 4 ,
352+ ) ,
353+ ( vec![ 0 , 2 , 3 ] , vec![ F32 ( -1.0 ) , F32 ( 1.0 ) , F32 ( 2.0 ) ] ) ,
354+ ) ,
358355 ] ;
359356 for ( e, parsed, filtered) in exprs {
360357 let ret = parse_pgvector_svector ( e. as_bytes ( ) , |s| s. parse :: < F32 > ( ) . ok ( ) ) ;
361358 assert ! ( ret. is_ok( ) , "at expr {:?}: {:?}" , e, ret) ;
362359 assert_eq ! ( ret. unwrap( ) , parsed, "parsed at expr {:?}" , e) ;
363360
364361 let ( indexes, values, _) = parsed;
365- let nonzero = svector_filter_nonzero ( & indexes, & values) ;
366- assert_eq ! ( nonzero, filtered, "filtered at expr {:?}" , e) ;
362+ let ( mut indexes, mut values) = svector_sorted ( & indexes, & values) ;
363+ svector_filter_nonzero ( & mut indexes, & mut values) ;
364+ assert_eq ! ( ( indexes, values) , filtered, "filtered at expr {:?}" , e) ;
367365 }
368366 }
369367}
0 commit comments