@@ -12,20 +12,29 @@ use scylla_cql::frame::frame_errors::{
1212} ;
1313use scylla_cql:: frame:: request;
1414use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
15- use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializedValues } ;
15+ use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializeRow , SerializedValues } ;
1616use scylla_cql:: serialize:: { RowWriter , SerializationError } ;
17+ use thiserror:: Error ;
18+ use tracing:: Instrument ;
1719
18- use crate :: client:: execution_profile:: ExecutionProfileHandle ;
20+ use crate :: client:: execution_profile:: { ExecutionProfileHandle , ExecutionProfileInner } ;
21+ use crate :: client:: session:: { RunRequestResult , Session } ;
1922use crate :: errors:: { BadQuery , ExecutionError , RequestAttemptError } ;
23+ use crate :: network:: Connection ;
24+ use crate :: observability:: driver_tracing:: RequestSpan ;
2025use crate :: observability:: history:: HistoryListener ;
2126use crate :: policies:: load_balancing:: LoadBalancingPolicy ;
27+ use crate :: policies:: load_balancing:: RoutingInfo ;
2228use crate :: policies:: retry:: RetryPolicy ;
29+ use crate :: response:: query_result:: QueryResult ;
30+ use crate :: response:: { Coordinator , NonErrorQueryResponse , QueryResponse } ;
2331use crate :: routing:: Token ;
2432use crate :: statement:: prepared:: { PartitionKeyError , PreparedStatement } ;
2533use crate :: statement:: unprepared:: Statement ;
2634
2735use super :: StatementConfig ;
2836use super :: bound:: BoundStatement ;
37+ use super :: execute:: Execute ;
2938use super :: { Consistency , SerialConsistency } ;
3039pub use crate :: frame:: request:: batch:: BatchType ;
3140
@@ -302,8 +311,8 @@ pub struct BoundBatch {
302311 batch_type : BatchType ,
303312 pub ( crate ) buffer : Vec < u8 > ,
304313 pub ( crate ) prepared : HashMap < Bytes , PreparedStatement > ,
305- pub ( crate ) first_prepared : Option < ( PreparedStatement , Token ) > ,
306- pub ( crate ) statements_len : u16 ,
314+ first_prepared : Option < ( PreparedStatement , Token ) > ,
315+ statements_len : u16 ,
307316}
308317
309318impl BoundBatch {
@@ -315,6 +324,20 @@ impl BoundBatch {
315324 }
316325 }
317326
327+ /// Appends a new statement to the batch.
328+ pub fn append_statement < ' p , V : SerializeRow > (
329+ & mut self ,
330+ statement : impl Into < BoundBatchStatement < ' p , V > > ,
331+ ) -> Result < ( ) , BoundBatchStatementError > {
332+ let initial_len = self . buffer . len ( ) ;
333+ self . raw_append_statement ( statement) . inspect_err ( |_| {
334+ // if we error'd at any point we should put the buffer back to its old length to not
335+ // corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
336+ // tries with a successful statement later
337+ self . buffer . truncate ( initial_len) ;
338+ } )
339+ }
340+
318341 #[ allow( clippy:: result_large_err) ]
319342 pub ( crate ) fn from_batch (
320343 batch : & Batch ,
@@ -424,6 +447,96 @@ impl BoundBatch {
424447 self . batch_type
425448 }
426449
450+ /// Gets the number of statements that have been added to this batch so far
451+ pub fn statements_len ( & self ) -> u16 {
452+ self . statements_len
453+ }
454+
455+ // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
456+ // because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
457+ // `self` to be modified if an error occured because the caller will not reset them.
458+ fn raw_append_statement < ' p , V : SerializeRow > (
459+ & mut self ,
460+ statement : impl Into < BoundBatchStatement < ' p , V > > ,
461+ ) -> Result < ( ) , BoundBatchStatementError > {
462+ let mut statement = statement. into ( ) ;
463+ let mut first_prepared = None ;
464+
465+ if self . statements_len == 0 {
466+ // save it into a local variable for now in case a latter steps fails
467+ first_prepared = match statement {
468+ BoundBatchStatement :: Bound ( ref b) => b
469+ . token ( ) ?
470+ . map ( |token| ( b. prepared . clone ( ) . into_owned ( ) , token) ) ,
471+ BoundBatchStatement :: Prepared ( ps, values) => {
472+ let bound = ps
473+ . into_bind ( & values)
474+ . map_err ( BatchStatementSerializationError :: ValuesSerialiation ) ?;
475+ let first_prepared = bound
476+ . token ( ) ?
477+ . map ( |token| ( bound. prepared . clone ( ) . into_owned ( ) , token) ) ;
478+ // we already serialized it so to avoid re-serializing it, modify the statement to a
479+ // BoundStatement
480+ statement = BoundBatchStatement :: Bound ( bound) ;
481+ first_prepared
482+ }
483+ BoundBatchStatement :: Query ( _) => None ,
484+ } ;
485+ }
486+
487+ let stmnt = match & statement {
488+ BoundBatchStatement :: Prepared ( ps, _) => request:: batch:: BatchStatement :: Prepared {
489+ id : Cow :: Borrowed ( ps. get_id ( ) ) ,
490+ } ,
491+ BoundBatchStatement :: Bound ( b) => request:: batch:: BatchStatement :: Prepared {
492+ id : Cow :: Borrowed ( b. prepared . get_id ( ) ) ,
493+ } ,
494+ BoundBatchStatement :: Query ( q) => request:: batch:: BatchStatement :: Query {
495+ text : Cow :: Borrowed ( & q. contents ) ,
496+ } ,
497+ } ;
498+
499+ serialize_statement ( stmnt, & mut self . buffer , |writer| match & statement {
500+ BoundBatchStatement :: Prepared ( ps, values) => {
501+ let ctx = RowSerializationContext :: from_prepared ( ps. get_prepared_metadata ( ) ) ;
502+ values. serialize ( & ctx, writer) . map ( Some )
503+ }
504+ BoundBatchStatement :: Bound ( b) => {
505+ writer. append_serialize_row ( & b. values ) ;
506+ Ok ( Some ( ( ) ) )
507+ }
508+ // query has no values
509+ BoundBatchStatement :: Query ( _) => Ok ( Some ( ( ) ) ) ,
510+ } ) ?;
511+
512+ let new_statements_len = self
513+ . statements_len
514+ . checked_add ( 1 )
515+ . ok_or ( BoundBatchStatementError :: TooManyQueriesInBatchStatement ) ?;
516+
517+ /*** at this point nothing else should be fallible as we are going to be modifying
518+ * fields that do not get reset ***/
519+
520+ self . statements_len = new_statements_len;
521+
522+ if let Some ( first_prepared) = first_prepared {
523+ self . first_prepared = Some ( first_prepared) ;
524+ }
525+
526+ let prepared = match statement {
527+ BoundBatchStatement :: Prepared ( ps, _) => Cow :: Owned ( ps) ,
528+ BoundBatchStatement :: Bound ( b) => b. prepared ,
529+ BoundBatchStatement :: Query ( _) => return Ok ( ( ) ) ,
530+ } ;
531+
532+ if !self . prepared . contains_key ( prepared. get_id ( ) ) {
533+ self . prepared
534+ . insert ( prepared. get_id ( ) . to_owned ( ) , prepared. into_owned ( ) ) ;
535+ }
536+
537+ Ok ( ( ) )
538+ }
539+
427540 #[ allow( clippy:: result_large_err) ]
428541 fn serialize_from_batch_statement < T > (
429542 & mut self ,
@@ -499,3 +612,129 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
499612 n_statements : n_statements as usize ,
500613 }
501614}
615+
616+ /// This enum represents a CQL statement, that can be part of batch and its values
617+ #[ derive( Clone ) ]
618+ #[ non_exhaustive]
619+ pub enum BoundBatchStatement < ' p , V : SerializeRow > {
620+ /// A prepared statement and its not-yet serialized values
621+ Prepared ( PreparedStatement , V ) ,
622+ /// A statement whose values have already been bound (and thus serialized)
623+ Bound ( BoundStatement < ' p > ) ,
624+ /// An unprepared statement with no values
625+ Query ( Statement ) ,
626+ }
627+
628+ impl < ' p > From < BoundStatement < ' p > > for BoundBatchStatement < ' p , ( ) > {
629+ fn from ( b : BoundStatement < ' p > ) -> Self {
630+ BoundBatchStatement :: Bound ( b)
631+ }
632+ }
633+
634+ impl < V : SerializeRow > From < ( PreparedStatement , V ) > for BoundBatchStatement < ' static , V > {
635+ fn from ( ( p, v) : ( PreparedStatement , V ) ) -> Self {
636+ BoundBatchStatement :: Prepared ( p, v)
637+ }
638+ }
639+
640+ impl From < Statement > for BoundBatchStatement < ' static , ( ) > {
641+ fn from ( s : Statement ) -> Self {
642+ BoundBatchStatement :: Query ( s)
643+ }
644+ }
645+
646+ impl From < & str > for BoundBatchStatement < ' static , ( ) > {
647+ fn from ( s : & str ) -> Self {
648+ BoundBatchStatement :: Query ( Statement :: from ( s) )
649+ }
650+ }
651+
652+ /// An error type returned when adding a statement to a bounded batch fails
653+ #[ non_exhaustive]
654+ #[ derive( Error , Debug , Clone ) ]
655+ pub enum BoundBatchStatementError {
656+ /// Failed to serialize the batch statement
657+ #[ error( transparent) ]
658+ Statement ( #[ from] BatchStatementSerializationError ) ,
659+ /// Failed to serialize statement's bound values.
660+ #[ error( "Failed to calculate partition key" ) ]
661+ PartitionKey ( #[ from] PartitionKeyError ) ,
662+ /// Too many statements in the batch statement.
663+ #[ error( "Added statement goes over exceeded max value of 65,535" ) ]
664+ TooManyQueriesInBatchStatement ,
665+ }
666+
667+ impl Execute for BoundBatch {
668+ async fn execute ( & self , session : & Session ) -> Result < QueryResult , ExecutionError > {
669+ // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
670+ // If users batch statements by shard, they will be rewarded with full shard awareness
671+ let execution_profile = self
672+ . get_execution_profile_handle ( )
673+ . unwrap_or_else ( || session. get_default_execution_profile_handle ( ) )
674+ . access ( ) ;
675+
676+ let consistency = self
677+ . config
678+ . consistency
679+ . unwrap_or ( execution_profile. consistency ) ;
680+
681+ let serial_consistency = self
682+ . config
683+ . serial_consistency
684+ . unwrap_or ( execution_profile. serial_consistency ) ;
685+
686+ let ( table, token) = self
687+ . first_prepared
688+ . as_ref ( )
689+ . and_then ( |( ps, token) | ps. get_table_spec ( ) . map ( |table| ( table, * token) ) )
690+ . unzip ( ) ;
691+
692+ let statement_info = RoutingInfo {
693+ consistency,
694+ serial_consistency,
695+ token,
696+ table,
697+ is_confirmed_lwt : false ,
698+ } ;
699+
700+ let span = RequestSpan :: new_batch ( ) ;
701+
702+ let ( run_request_result, coordinator) : (
703+ RunRequestResult < NonErrorQueryResponse > ,
704+ Coordinator ,
705+ ) = session
706+ . run_request (
707+ statement_info,
708+ & self . config ,
709+ execution_profile,
710+ |connection : Arc < Connection > ,
711+ consistency : Consistency ,
712+ execution_profile : & ExecutionProfileInner | {
713+ let serial_consistency = self
714+ . config
715+ . serial_consistency
716+ . unwrap_or ( execution_profile. serial_consistency ) ;
717+ async move {
718+ connection
719+ . batch_with_consistency ( self , consistency, serial_consistency)
720+ . await
721+ . and_then ( QueryResponse :: into_non_error_query_response)
722+ }
723+ } ,
724+ & span,
725+ )
726+ . instrument ( span. span ( ) . clone ( ) )
727+ . await ?;
728+
729+ let result = match run_request_result {
730+ RunRequestResult :: IgnoredWriteError => QueryResult :: mock_empty ( coordinator) ,
731+ RunRequestResult :: Completed ( non_error_query_response) => {
732+ let result = non_error_query_response. into_query_result ( coordinator) ?;
733+ span. record_result_fields ( & result) ;
734+ result
735+ }
736+ } ;
737+
738+ Ok ( result)
739+ }
740+ }
0 commit comments