Skip to content

Commit 522e0aa

Browse files
committed
Allow adding statements to a boundbatch and executing it
1 parent fada0e6 commit 522e0aa

File tree

8 files changed

+661
-373
lines changed

8 files changed

+661
-373
lines changed

scylla/src/client/session.rs

Lines changed: 20 additions & 358 deletions
Large diffs are not rendered by default.

scylla/src/network/connection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ impl Connection {
10421042
consistency,
10431043
serial_consistency,
10441044
timestamp,
1045-
statements_len: batch.statements_len,
1045+
statements_len: batch.statements_len(),
10461046
};
10471047

10481048
loop {

scylla/src/statement/batch.rs

Lines changed: 243 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,29 @@ use scylla_cql::frame::frame_errors::{
1212
};
1313
use scylla_cql::frame::request;
1414
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
15-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
15+
use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
1616
use scylla_cql::serialize::{RowWriter, SerializationError};
17+
use thiserror::Error;
18+
use tracing::Instrument;
1719

18-
use crate::client::execution_profile::ExecutionProfileHandle;
20+
use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner};
21+
use crate::client::session::{RunRequestResult, Session};
1922
use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
23+
use crate::network::Connection;
24+
use crate::observability::driver_tracing::RequestSpan;
2025
use crate::observability::history::HistoryListener;
2126
use crate::policies::load_balancing::LoadBalancingPolicy;
27+
use crate::policies::load_balancing::RoutingInfo;
2228
use crate::policies::retry::RetryPolicy;
29+
use crate::response::query_result::QueryResult;
30+
use crate::response::{Coordinator, NonErrorQueryResponse, QueryResponse};
2331
use crate::routing::Token;
2432
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
2533
use crate::statement::unprepared::Statement;
2634

2735
use super::StatementConfig;
2836
use super::bound::BoundStatement;
37+
use super::execute::Execute;
2938
use super::{Consistency, SerialConsistency};
3039
pub use crate::frame::request::batch::BatchType;
3140

@@ -302,8 +311,8 @@ pub struct BoundBatch {
302311
batch_type: BatchType,
303312
pub(crate) buffer: Vec<u8>,
304313
pub(crate) prepared: HashMap<Bytes, PreparedStatement>,
305-
pub(crate) first_prepared: Option<(PreparedStatement, Token)>,
306-
pub(crate) statements_len: u16,
314+
first_prepared: Option<(PreparedStatement, Token)>,
315+
statements_len: u16,
307316
}
308317

309318
impl BoundBatch {
@@ -315,6 +324,20 @@ impl BoundBatch {
315324
}
316325
}
317326

327+
/// Appends a new statement to the batch.
328+
pub fn append_statement<'p, V: SerializeRow>(
329+
&mut self,
330+
statement: impl Into<BoundBatchStatement<'p, V>>,
331+
) -> Result<(), BoundBatchStatementError> {
332+
let initial_len = self.buffer.len();
333+
self.raw_append_statement(statement).inspect_err(|_| {
334+
// if we errored at any point, restore the buffer to its old length so that we do not
335+
// corrupt it in case the user doesn't drop the `BoundBatch` but instead skips the failed
336+
// statement and tries to append another statement later
337+
self.buffer.truncate(initial_len);
338+
})
339+
}
340+
318341
#[allow(clippy::result_large_err)]
319342
pub(crate) fn from_batch(
320343
batch: &Batch,
@@ -424,6 +447,96 @@ impl BoundBatch {
424447
self.batch_type
425448
}
426449

450+
/// Gets the number of statements that have been added to this batch so far
451+
pub fn statements_len(&self) -> u16 {
452+
self.statements_len
453+
}
454+
455+
// **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
456+
// because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
457+
// `self` to be modified if an error occurred because the caller will not reset them.
458+
fn raw_append_statement<'p, V: SerializeRow>(
459+
&mut self,
460+
statement: impl Into<BoundBatchStatement<'p, V>>,
461+
) -> Result<(), BoundBatchStatementError> {
462+
let mut statement = statement.into();
463+
let mut first_prepared = None;
464+
465+
if self.statements_len == 0 {
466+
// save it into a local variable for now in case a later step fails
467+
first_prepared = match statement {
468+
BoundBatchStatement::Bound(ref b) => b
469+
.token()?
470+
.map(|token| (b.prepared.clone().into_owned(), token)),
471+
BoundBatchStatement::Prepared(ps, values) => {
472+
let bound = ps
473+
.into_bind(&values)
474+
.map_err(BatchStatementSerializationError::ValuesSerialiation)?;
475+
let first_prepared = bound
476+
.token()?
477+
.map(|token| (bound.prepared.clone().into_owned(), token));
478+
// we already serialized it so to avoid re-serializing it, modify the statement to a
479+
// BoundStatement
480+
statement = BoundBatchStatement::Bound(bound);
481+
first_prepared
482+
}
483+
BoundBatchStatement::Query(_) => None,
484+
};
485+
}
486+
487+
let stmnt = match &statement {
488+
BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
489+
id: Cow::Borrowed(ps.get_id()),
490+
},
491+
BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
492+
id: Cow::Borrowed(b.prepared.get_id()),
493+
},
494+
BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
495+
text: Cow::Borrowed(&q.contents),
496+
},
497+
};
498+
499+
serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
500+
BoundBatchStatement::Prepared(ps, values) => {
501+
let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
502+
values.serialize(&ctx, writer).map(Some)
503+
}
504+
BoundBatchStatement::Bound(b) => {
505+
writer.append_serialize_row(&b.values);
506+
Ok(Some(()))
507+
}
508+
// query has no values
509+
BoundBatchStatement::Query(_) => Ok(Some(())),
510+
})?;
511+
512+
let new_statements_len = self
513+
.statements_len
514+
.checked_add(1)
515+
.ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
516+
517+
/*** at this point nothing else should be fallible as we are going to be modifying
518+
* fields that do not get reset ***/
519+
520+
self.statements_len = new_statements_len;
521+
522+
if let Some(first_prepared) = first_prepared {
523+
self.first_prepared = Some(first_prepared);
524+
}
525+
526+
let prepared = match statement {
527+
BoundBatchStatement::Prepared(ps, _) => Cow::Owned(ps),
528+
BoundBatchStatement::Bound(b) => b.prepared,
529+
BoundBatchStatement::Query(_) => return Ok(()),
530+
};
531+
532+
if !self.prepared.contains_key(prepared.get_id()) {
533+
self.prepared
534+
.insert(prepared.get_id().to_owned(), prepared.into_owned());
535+
}
536+
537+
Ok(())
538+
}
539+
427540
#[allow(clippy::result_large_err)]
428541
fn serialize_from_batch_statement<T>(
429542
&mut self,
@@ -499,3 +612,129 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
499612
n_statements: n_statements as usize,
500613
}
501614
}
615+
616+
/// Represents a CQL statement that can be part of a batch, together with its values
617+
#[derive(Clone)]
618+
#[non_exhaustive]
619+
pub enum BoundBatchStatement<'p, V: SerializeRow> {
620+
/// A prepared statement and its not-yet serialized values
621+
Prepared(PreparedStatement, V),
622+
/// A statement whose values have already been bound (and thus serialized)
623+
Bound(BoundStatement<'p>),
624+
/// An unprepared statement with no values
625+
Query(Statement),
626+
}
627+
628+
impl<'p> From<BoundStatement<'p>> for BoundBatchStatement<'p, ()> {
629+
fn from(b: BoundStatement<'p>) -> Self {
630+
BoundBatchStatement::Bound(b)
631+
}
632+
}
633+
634+
impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<'static, V> {
635+
fn from((p, v): (PreparedStatement, V)) -> Self {
636+
BoundBatchStatement::Prepared(p, v)
637+
}
638+
}
639+
640+
impl From<Statement> for BoundBatchStatement<'static, ()> {
641+
fn from(s: Statement) -> Self {
642+
BoundBatchStatement::Query(s)
643+
}
644+
}
645+
646+
impl From<&str> for BoundBatchStatement<'static, ()> {
647+
fn from(s: &str) -> Self {
648+
BoundBatchStatement::Query(Statement::from(s))
649+
}
650+
}
651+
652+
/// An error type returned when adding a statement to a bound batch fails
653+
#[non_exhaustive]
654+
#[derive(Error, Debug, Clone)]
655+
pub enum BoundBatchStatementError {
656+
/// Failed to serialize the batch statement
657+
#[error(transparent)]
658+
Statement(#[from] BatchStatementSerializationError),
659+
/// Failed to calculate the partition key of the statement.
660+
#[error("Failed to calculate partition key")]
661+
PartitionKey(#[from] PartitionKeyError),
662+
/// Too many statements in the batch statement.
663+
#[error("Added statement goes over exceeded max value of 65,535")]
664+
TooManyQueriesInBatchStatement,
665+
}
666+
667+
impl Execute for BoundBatch {
668+
async fn execute(&self, session: &Session) -> Result<QueryResult, ExecutionError> {
669+
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
670+
// If users batch statements by shard, they will be rewarded with full shard awareness
671+
let execution_profile = self
672+
.get_execution_profile_handle()
673+
.unwrap_or_else(|| session.get_default_execution_profile_handle())
674+
.access();
675+
676+
let consistency = self
677+
.config
678+
.consistency
679+
.unwrap_or(execution_profile.consistency);
680+
681+
let serial_consistency = self
682+
.config
683+
.serial_consistency
684+
.unwrap_or(execution_profile.serial_consistency);
685+
686+
let (table, token) = self
687+
.first_prepared
688+
.as_ref()
689+
.and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
690+
.unzip();
691+
692+
let statement_info = RoutingInfo {
693+
consistency,
694+
serial_consistency,
695+
token,
696+
table,
697+
is_confirmed_lwt: false,
698+
};
699+
700+
let span = RequestSpan::new_batch();
701+
702+
let (run_request_result, coordinator): (
703+
RunRequestResult<NonErrorQueryResponse>,
704+
Coordinator,
705+
) = session
706+
.run_request(
707+
statement_info,
708+
&self.config,
709+
execution_profile,
710+
|connection: Arc<Connection>,
711+
consistency: Consistency,
712+
execution_profile: &ExecutionProfileInner| {
713+
let serial_consistency = self
714+
.config
715+
.serial_consistency
716+
.unwrap_or(execution_profile.serial_consistency);
717+
async move {
718+
connection
719+
.batch_with_consistency(self, consistency, serial_consistency)
720+
.await
721+
.and_then(QueryResponse::into_non_error_query_response)
722+
}
723+
},
724+
&span,
725+
)
726+
.instrument(span.span().clone())
727+
.await?;
728+
729+
let result = match run_request_result {
730+
RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(coordinator),
731+
RunRequestResult::Completed(non_error_query_response) => {
732+
let result = non_error_query_response.into_query_result(coordinator)?;
733+
span.record_result_fields(&result);
734+
result
735+
}
736+
};
737+
738+
Ok(result)
739+
}
740+
}

0 commit comments

Comments
 (0)