@@ -4,12 +4,19 @@ use crate::parser_util::*;
44use crate :: sql_types:: * ;
55use graphql_parser:: query:: * ;
66use serde:: Serialize ;
7- use std:: collections:: HashMap ;
7+ use std:: collections:: { HashMap , HashSet } ;
88use std:: hash:: Hash ;
99use std:: ops:: Deref ;
1010use std:: str:: FromStr ;
1111use std:: sync:: Arc ;
1212
13+ #[ derive( Clone , Debug ) ]
14+ pub struct OnConflictBuilder {
15+ pub constraint : Index , // Could probably get away with a name ref
16+ pub update_fields : HashSet < Arc < Column > > , // Could probably get away with a name ref
17+ pub filter : FilterBuilder ,
18+ }
19+
1320#[ derive( Clone , Debug ) ]
1421pub struct InsertBuilder {
1522 pub alias : String ,
@@ -22,6 +29,8 @@ pub struct InsertBuilder {
2229
2330 //fields
2431 pub selections : Vec < InsertSelection > ,
32+
33+ pub on_conflict : Option < OnConflictBuilder > ,
2534}
2635
2736#[ derive( Clone , Debug ) ]
@@ -176,6 +185,90 @@ where
176185 parse_node_id ( node_id_base64_encoded_json_string)
177186}
178187
188+ fn read_argument_on_conflict < ' a , T > (
189+ field : & __Field ,
190+ query_field : & graphql_parser:: query:: Field < ' a , T > ,
191+ variables : & serde_json:: Value ,
192+ variable_definitions : & Vec < VariableDefinition < ' a , T > > ,
193+ ) -> Result < Option < OnConflictBuilder > , String >
194+ where
195+ T : Text < ' a > + Eq + AsRef < str > ,
196+ {
197+ let validated: gson:: Value = read_argument (
198+ "onConflict" ,
199+ field,
200+ query_field,
201+ variables,
202+ variable_definitions,
203+ ) ?;
204+
205+ let insert_type: InsertOnConflictType = match field. get_arg ( "onConflict" ) {
206+ None => return Ok ( None ) ,
207+ Some ( x) => match x. type_ ( ) . unmodified_type ( ) {
208+ __Type:: InsertOnConflictInput ( insert_on_conflict) => insert_on_conflict,
209+ _ => return Err ( "Could not locate Insert Entity type" . to_string ( ) ) ,
210+ } ,
211+ } ;
212+
213+ let filter: FilterBuilder =
214+ read_argument_filter ( field, query_field, variables, variable_definitions) ?;
215+
216+ let on_conflict_builder = match validated {
217+ gson:: Value :: Absent | gson:: Value :: Null => None ,
218+ gson:: Value :: Object ( contents) => {
219+ let constraint = match contents
220+ . get ( "constraint" )
221+ . expect ( "OnConflict revalidation error. Expected constraint" )
222+ {
223+ gson:: Value :: String ( ix_name) => insert_type
224+ . table
225+ . indexes
226+ . iter ( )
227+ . find ( |ix| & ix. name == ix_name)
228+ . expect ( "OnConflict revalidation error. constraint: unknown constraint name" ) ,
229+ _ => {
230+ return Err (
231+ "OnConflict revalidation error. Expected constraint as String" . to_string ( ) ,
232+ )
233+ }
234+ } ;
235+
236+ let update_fields = match contents
237+ . get ( "updateFields" )
238+ . expect ( "OnConflict revalidation error. Expected updateFields" )
239+ {
240+ gson:: Value :: Array ( col_names) => {
241+ let mut update_columns: HashSet < Arc < Column > > = HashSet :: new ( ) ;
242+ for col_name in col_names {
243+ match col_name {
244+ gson:: Value :: String ( c) => {
245+ let col = insert_type. table . columns . iter ( ) . find ( |column| & column. name == c) . expect ( "OnConflict revalidation error. updateFields: unknown column name" ) ;
246+ update_columns. insert ( Arc :: clone ( col) ) ;
247+ }
248+ _ => return Err ( "OnConflict revalidation error. Expected updateFields to be column names" . to_string ( ) ) ,
249+ }
250+ }
251+ update_columns
252+ }
253+ _ => {
254+ return Err (
255+ "OnConflict revalidation error. Expected updateFields to be an array"
256+ . to_string ( ) ,
257+ )
258+ }
259+ } ;
260+
261+ Some ( OnConflictBuilder {
262+ constraint : constraint. clone ( ) ,
263+ update_fields,
264+ filter,
265+ } )
266+ }
267+ _ => return Err ( "Insert re-validation errror" . to_string ( ) ) ,
268+ } ;
269+ Ok ( on_conflict_builder)
270+ }
271+
179272fn read_argument_objects < ' a , T > (
180273 field : & __Field ,
181274 query_field : & graphql_parser:: query:: Field < ' a , T > ,
@@ -277,11 +370,14 @@ where
277370 match & type_ {
278371 __Type:: InsertResponse ( xtype) => {
279372 // Raise for disallowed arguments
280- restrict_allowed_arguments ( & [ "objects" ] , query_field) ?;
373+ restrict_allowed_arguments ( & [ "objects" , "onConflict" ] , query_field) ?;
281374
282375 let objects: Vec < InsertRowBuilder > =
283376 read_argument_objects ( field, query_field, variables, variable_definitions) ?;
284377
378+ let on_conflict: Option < OnConflictBuilder > =
379+ read_argument_on_conflict ( field, query_field, variables, variable_definitions) ?;
380+
285381 let mut builder_fields: Vec < InsertSelection > = vec ! [ ] ;
286382
287383 let selection_fields = normalize_selection_set (
@@ -324,6 +420,7 @@ where
324420 table : Arc :: clone ( & xtype. table ) ,
325421 objects,
326422 selections : builder_fields,
423+ on_conflict,
327424 } )
328425 }
329426 _ => Err ( format ! (
0 commit comments