11use scylla:: batch:: Batch ;
22use scylla:: batch:: BatchType ;
3+ use scylla:: client:: session:: Session ;
34use scylla:: errors:: { ExecutionError , RequestAttemptError } ;
45use scylla:: frame:: frame_errors:: BatchSerializationError ;
56use scylla:: frame:: frame_errors:: CqlRequestSerializationError ;
7+ use scylla:: prepared_statement:: PreparedStatement ;
68use scylla:: query:: Query ;
9+ use scylla:: value:: { CqlValue , MaybeUnset } ;
10+ use std:: collections:: HashMap ;
11+ use std:: string:: String ;
712
813use crate :: utils:: create_new_session_builder;
914use crate :: utils:: setup_tracing;
@@ -12,6 +17,64 @@ use crate::utils::PerformDDL;
1217
1318use assert_matches:: assert_matches;
1419
20+ const BATCH_COUNT : usize = 100 ;
21+
22+ async fn create_test_session ( table_name : & str ) -> Session {
23+ let session = create_new_session_builder ( ) . build ( ) . await . unwrap ( ) ;
24+ let ks = unique_keyspace_name ( ) ;
25+ session
26+ . ddl ( format ! (
27+ "CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}" , ks) )
28+ . await
29+ . unwrap ( ) ;
30+ session. use_keyspace ( & ks, false ) . await . unwrap ( ) ;
31+ session
32+ . ddl ( format ! (
33+ "CREATE TABLE IF NOT EXISTS {} (k0 text, k1 int, v int, PRIMARY KEY (k0, k1))" ,
34+ table_name
35+ ) )
36+ . await
37+ . unwrap ( ) ;
38+ session
39+ }
40+
41+ async fn create_counter_tables ( session : & Session ) {
42+ for & table in [ "counter1" , "counter2" , "counter3" ] . iter ( ) {
43+ session
44+ . ddl ( format ! (
45+ "CREATE TABLE {} (k0 text PRIMARY KEY, c counter)" ,
46+ table
47+ ) )
48+ . await
49+ . unwrap ( ) ;
50+ }
51+ }
52+
53+ async fn verify_batch_insert ( session : & Session , test_name : & str , count : usize ) {
54+ let select_query = format ! ( "SELECT k0, k1, v FROM {} WHERE k0 = ?" , test_name) ;
55+ let query_result = session
56+ . query_unpaged ( select_query, ( test_name, ) )
57+ . await
58+ . unwrap ( )
59+ . into_rows_result ( )
60+ . unwrap ( ) ;
61+ let rows: Vec < ( String , i32 , i32 ) > = query_result
62+ . rows :: < ( String , i32 , i32 ) > ( )
63+ . unwrap ( )
64+ . map ( |r| r. unwrap ( ) )
65+ . collect ( ) ;
66+ assert_eq ! ( rows. len( ) , count) ;
67+ for ( k0, k1, v) in rows {
68+ assert_eq ! ( k0, test_name) ;
69+ assert_eq ! ( v, k1 + 1 ) ;
70+ }
71+ }
72+
73+ async fn prepare_insert_statement ( session : & Session , table : & str ) -> PreparedStatement {
74+ let query_str = format ! ( "INSERT INTO {} (k0, k1, v) VALUES (?, ?, ?)" , table) ;
75+ session. prepare ( Query :: new ( query_str) ) . await . unwrap ( )
76+ }
77+
1578#[ tokio:: test]
1679#[ ntest:: timeout( 60000 ) ]
1780async fn batch_statements_and_values_mismatch_detected ( ) {
@@ -76,3 +139,310 @@ async fn batch_statements_and_values_mismatch_detected() {
76139 )
77140 }
78141}
142+
143+ #[ tokio:: test]
144+ async fn test_batch_of_simple_statements ( ) {
145+ setup_tracing ( ) ;
146+ let test_name = String :: from ( "test_batch_simple_statements" ) ;
147+ let session = create_test_session ( & test_name) . await ;
148+
149+ let mut batch = Batch :: new ( BatchType :: Unlogged ) ;
150+ for i in 0 ..BATCH_COUNT {
151+ let simple_statement = Query :: new ( format ! (
152+ "INSERT INTO {} (k0, k1, v) VALUES ('{}', {}, {})" ,
153+ & test_name,
154+ & test_name,
155+ i,
156+ i + 1
157+ ) ) ;
158+ batch. append_statement ( simple_statement) ;
159+ }
160+ session. batch ( & batch, vec ! [ ( ) ; BATCH_COUNT ] ) . await . unwrap ( ) ;
161+ verify_batch_insert ( & session, & test_name, BATCH_COUNT ) . await ;
162+ }
163+
164+ #[ tokio:: test]
165+ async fn test_batch_of_bound_statements ( ) {
166+ setup_tracing ( ) ;
167+ let test_name = String :: from ( "test_batch_bound_statements" ) ;
168+ let session = create_test_session ( & test_name) . await ;
169+
170+ let prepared = prepare_insert_statement ( & session, & test_name) . await ;
171+ let mut batch = Batch :: new ( BatchType :: Unlogged ) ;
172+ let mut batch_values: Vec < _ > = Vec :: with_capacity ( BATCH_COUNT ) ;
173+ for i in 0 ..BATCH_COUNT as i32 {
174+ batch. append_statement ( prepared. clone ( ) ) ;
175+ batch_values. push ( ( test_name. as_str ( ) , i, i + 1 ) ) ;
176+ }
177+ session. batch ( & batch, batch_values) . await . unwrap ( ) ;
178+ verify_batch_insert ( & session, & test_name, BATCH_COUNT ) . await ;
179+ }
180+
181+ #[ tokio:: test]
182+ async fn test_prepared_batch ( ) {
183+ setup_tracing ( ) ;
184+ let test_name = String :: from ( "test_prepared_batch" ) ;
185+ let session = create_test_session ( & test_name) . await ;
186+
187+ let mut batch = Batch :: new ( BatchType :: Unlogged ) ;
188+ let mut batch_values = Vec :: with_capacity ( BATCH_COUNT ) ;
189+ let query_str = format ! ( "INSERT INTO {} (k0, k1, v) VALUES (?, ?, ?)" , & test_name) ;
190+ for i in 0 ..BATCH_COUNT as i32 {
191+ batch. append_statement ( Query :: new ( query_str. clone ( ) ) ) ;
192+ batch_values. push ( ( & test_name, i, i + 1 ) ) ;
193+ }
194+ let prepared_batch = session. prepare_batch ( & batch) . await . unwrap ( ) ;
195+ session. batch ( & prepared_batch, batch_values) . await . unwrap ( ) ;
196+ verify_batch_insert ( & session, & test_name, BATCH_COUNT ) . await ;
197+ }
198+
199+ #[ tokio:: test]
200+ async fn test_batch_of_bound_statements_with_unset_values ( ) {
201+ setup_tracing ( ) ;
202+ let test_name = String :: from ( "test_batch_bound_statements_with_unset_values" ) ;
203+ let session = create_test_session ( & test_name) . await ;
204+
205+ let prepared = prepare_insert_statement ( & session, & test_name) . await ;
206+ let mut batch1 = Batch :: new ( BatchType :: Unlogged ) ;
207+ let mut batch1_values = Vec :: with_capacity ( BATCH_COUNT ) ;
208+ for i in 0 ..BATCH_COUNT as i32 {
209+ batch1. append_statement ( prepared. clone ( ) ) ;
210+ batch1_values. push ( ( test_name. as_str ( ) , i, i + 1 ) ) ;
211+ }
212+ session. batch ( & batch1, batch1_values) . await . unwrap ( ) ;
213+
214+ // Update v to (k1 + 2), but for every 20th row leave v unset.
215+ let mut batch2 = Batch :: new ( BatchType :: Unlogged ) ;
216+ let mut batch2_values = Vec :: with_capacity ( BATCH_COUNT ) ;
217+ for i in 0 ..BATCH_COUNT as i32 {
218+ batch2. append_statement ( prepared. clone ( ) ) ;
219+ if i % 20 == 0 {
220+ batch2_values. push ( (
221+ MaybeUnset :: Set ( & test_name) ,
222+ MaybeUnset :: Set ( i) ,
223+ MaybeUnset :: Unset ,
224+ ) ) ;
225+ } else {
226+ batch2_values. push ( (
227+ MaybeUnset :: Set ( & test_name) ,
228+ MaybeUnset :: Set ( i) ,
229+ MaybeUnset :: Set ( i + 2 ) ,
230+ ) ) ;
231+ }
232+ }
233+ session. batch ( & batch2, batch2_values) . await . unwrap ( ) ;
234+
235+ // Verify that rows with k1 % 20 == 0 retain the original value.
236+ let select_query = format ! ( "SELECT k0, k1, v FROM {} WHERE k0 = ?" , & test_name) ;
237+ let query_result = session
238+ . query_unpaged ( select_query, ( & test_name, ) )
239+ . await
240+ . unwrap ( )
241+ . into_rows_result ( )
242+ . unwrap ( ) ;
243+ let rows: Vec < ( String , i32 , i32 ) > = query_result
244+ . rows :: < ( String , i32 , i32 ) > ( )
245+ . unwrap ( )
246+ . map ( |r| r. unwrap ( ) )
247+ . collect ( ) ;
248+ assert_eq ! (
249+ rows. len( ) ,
250+ BATCH_COUNT ,
251+ "Expected {} rows, got {}" ,
252+ BATCH_COUNT ,
253+ rows. len( )
254+ ) ;
255+ for ( k0, k1, v) in rows {
256+ assert_eq ! ( k0, test_name) ;
257+ assert_eq ! ( v, if k1 % 20 == 0 { k1 + 1 } else { k1 + 2 } ) ;
258+ }
259+ }
260+
261+ #[ tokio:: test]
262+ async fn test_batch_of_bound_statements_named_variables ( ) {
263+ setup_tracing ( ) ;
264+ let test_name = String :: from ( "test_batch_bound_statements_named_variables" ) ;
265+ let session = create_test_session ( & test_name) . await ;
266+
267+ let query_str = format ! (
268+ "INSERT INTO {} (k0, k1, v) VALUES (:k0, :k1, :v)" ,
269+ & test_name
270+ ) ;
271+ let prepared = session. prepare ( query_str) . await . unwrap ( ) ;
272+
273+ let mut batch = Batch :: new ( BatchType :: Unlogged ) ;
274+ let mut batch_values = Vec :: with_capacity ( BATCH_COUNT ) ;
275+ for i in 0 ..BATCH_COUNT as i32 {
276+ batch. append_statement ( prepared. clone ( ) ) ;
277+ let mut values = HashMap :: new ( ) ;
278+ values. insert ( "k0" , CqlValue :: Text ( test_name. clone ( ) ) ) ;
279+ values. insert ( "k1" , CqlValue :: Int ( i) ) ;
280+ values. insert ( "v" , CqlValue :: Int ( i + 1 ) ) ;
281+ batch_values. push ( values) ;
282+ }
283+ session. batch ( & batch, batch_values) . await . unwrap ( ) ;
284+ verify_batch_insert ( & session, & test_name, BATCH_COUNT ) . await ;
285+ }
286+
287+ #[ tokio:: test]
288+ async fn test_batch_of_mixed_bound_and_simple_statements ( ) {
289+ setup_tracing ( ) ;
290+ let test_name = String :: from ( "test_batch_mixed_bound_and_simple_statements" ) ;
291+ let session = create_test_session ( & test_name) . await ;
292+
293+ let query_str = format ! ( "INSERT INTO {} (k0, k1, v) VALUES (?, ?, ?)" , & test_name) ;
294+ let prepared_bound = session
295+ . prepare ( Query :: new ( query_str. clone ( ) ) )
296+ . await
297+ . unwrap ( ) ;
298+
299+ let mut batch = Batch :: new ( BatchType :: Unlogged ) ;
300+ let mut batch_values = Vec :: with_capacity ( BATCH_COUNT ) ;
301+ for i in 0 ..BATCH_COUNT as i32 {
302+ if i % 2 == 1 {
303+ let simple_statement = Query :: new ( format ! (
304+ "INSERT INTO {} (k0, k1, v) VALUES ('{}', {}, {})" ,
305+ & test_name,
306+ & test_name,
307+ i,
308+ i + 1
309+ ) ) ;
310+ batch. append_statement ( simple_statement) ;
311+ batch_values. push ( vec ! [ ] ) ;
312+ } else {
313+ batch. append_statement ( prepared_bound. clone ( ) ) ;
314+ batch_values. push ( vec ! [
315+ CqlValue :: Text ( test_name. clone( ) ) ,
316+ CqlValue :: Int ( i) ,
317+ CqlValue :: Int ( i + 1 ) ,
318+ ] ) ;
319+ }
320+ }
321+ session. batch ( & batch, batch_values) . await . unwrap ( ) ;
322+ verify_batch_insert ( & session, & test_name, BATCH_COUNT ) . await ;
323+ }
324+
325+ /// TODO: Remove #[ignore] once LWTs are supported with tablets.
326+ #[ tokio:: test]
327+ #[ ignore]
328+ async fn test_cas_batch ( ) {
329+ setup_tracing ( ) ;
330+ let test_name = String :: from ( "test_cas_batch" ) ;
331+ let session = create_test_session ( & test_name) . await ;
332+
333+ let prepared = prepare_insert_statement ( & session, & test_name) . await ;
334+ let mut batch = Batch :: new ( BatchType :: Unlogged ) ;
335+ let mut batch_values = Vec :: with_capacity ( BATCH_COUNT ) ;
336+ for i in 0 ..BATCH_COUNT as i32 {
337+ batch. append_statement ( prepared. clone ( ) ) ;
338+ batch_values. push ( ( & test_name, i, i + 1 ) ) ;
339+ }
340+ let result = session. batch ( & batch, batch_values. clone ( ) ) . await . unwrap ( ) ;
341+ let ( applied, ) : ( bool , ) = result
342+ . into_rows_result ( )
343+ . unwrap ( )
344+ . first_row :: < ( bool , ) > ( )
345+ . unwrap ( ) ;
346+ assert ! ( applied, "First CAS batch should be applied" ) ;
347+
348+ verify_batch_insert ( & session, & test_name, BATCH_COUNT ) . await ;
349+
350+ let result2 = session. batch ( & batch, batch_values) . await . unwrap ( ) ;
351+ let ( applied2, ) : ( bool , ) = result2
352+ . into_rows_result ( )
353+ . unwrap ( )
354+ . first_row :: < ( bool , ) > ( )
355+ . unwrap ( ) ;
356+ assert ! ( applied2, "Second CAS batch should not be applied" ) ;
357+ }
358+
359+ /// TODO: Remove #[ignore] once counters are supported with tablets.
360+ #[ tokio:: test]
361+ #[ ignore]
362+ async fn test_counter_batch ( ) {
363+ setup_tracing ( ) ;
364+ let test_name = String :: from ( "test_counter_batch" ) ;
365+ let session = create_test_session ( & test_name) . await ;
366+ create_counter_tables ( & session) . await ;
367+
368+ let mut batch = Batch :: new ( BatchType :: Counter ) ;
369+ let mut batch_values = Vec :: with_capacity ( 3 ) ;
370+ for i in 1 ..=3 {
371+ let query_str = format ! ( "UPDATE counter{} SET c = c + ? WHERE k0 = ?" , i) ;
372+ let prepared = session. prepare ( Query :: new ( query_str) ) . await . unwrap ( ) ;
373+ batch. append_statement ( prepared) ;
374+ batch_values. push ( ( i, & test_name) ) ;
375+ }
376+ session. batch ( & batch, batch_values) . await . unwrap ( ) ;
377+
378+ for i in 1 ..=3 {
379+ let query_str = format ! ( "SELECT c FROM counter{} WHERE k0 = ?" , i) ;
380+ let query_result = session
381+ . query_unpaged ( query_str, ( & test_name, ) )
382+ . await
383+ . unwrap ( )
384+ . into_rows_result ( )
385+ . unwrap ( ) ;
386+ let row = query_result. single_row :: < ( i64 , ) > ( ) . unwrap ( ) ;
387+ let ( c, ) = row;
388+ assert_eq ! ( c, i as i64 ) ;
389+ }
390+ }
391+
392+ /// TODO: Remove #[ignore] once counters are supported with tablets.
393+ #[ tokio:: test]
394+ #[ ignore]
395+ async fn test_fail_logged_batch_with_counter_increment ( ) {
396+ setup_tracing ( ) ;
397+ let test_name = String :: from ( "test_fail_logged_batch" ) ;
398+ let session = create_test_session ( & test_name) . await ;
399+ create_counter_tables ( & session) . await ;
400+
401+ let mut batch = Batch :: new ( BatchType :: Logged ) ;
402+ let mut batch_values: Vec < _ > = Vec :: with_capacity ( 3 ) ;
403+ for i in 1 ..=3 {
404+ let query_str = format ! ( "UPDATE counter{} SET c = c + ? WHERE k0 = ?" , i) ;
405+ let prepared = session. prepare ( Query :: new ( query_str) ) . await . unwrap ( ) ;
406+ batch. append_statement ( prepared) ;
407+ batch_values. push ( ( i, & test_name) ) ;
408+ }
409+ let err = session. batch ( & batch, batch_values) . await . unwrap_err ( ) ;
410+ assert_matches ! (
411+ err,
412+ ExecutionError :: BadQuery ( _) ,
413+ "Expected a BadQuery error when using counter statements in a LOGGED batch"
414+ ) ;
415+ }
416+
417+ /// TODO: Remove #[ignore] once counters are supported with tablets.
418+ #[ tokio:: test]
419+ #[ ignore]
420+ async fn test_fail_counter_batch_with_non_counter_increment ( ) {
421+ setup_tracing ( ) ;
422+ let test_name = String :: from ( "test_fail_counter_batch" ) ;
423+ let session = create_test_session ( & test_name) . await ;
424+ create_counter_tables ( & session) . await ;
425+
426+ let mut batch = Batch :: new ( BatchType :: Counter ) ;
427+ let mut batch_values: Vec < Vec < CqlValue > > = Vec :: new ( ) ;
428+ for i in 1 ..=3 {
429+ let query_str = format ! ( "UPDATE counter{} SET c = c + ? WHERE k0 = ?" , i) ;
430+ let prepared = session. prepare ( Query :: new ( query_str) ) . await . unwrap ( ) ;
431+ batch. append_statement ( prepared) ;
432+ batch_values. push ( vec ! [ CqlValue :: Int ( i) , CqlValue :: Text ( test_name. clone( ) ) ] ) ;
433+ }
434+
435+ let prepared = prepare_insert_statement ( & session, & test_name) . await ;
436+ batch. append_statement ( prepared) ;
437+ batch_values. push ( vec ! [
438+ CqlValue :: Text ( test_name. clone( ) ) ,
439+ CqlValue :: Int ( 0 ) ,
440+ CqlValue :: Int ( 1 ) ,
441+ ] ) ;
442+ let err = session. batch ( & batch, batch_values) . await . unwrap_err ( ) ;
443+ assert_matches ! (
444+ err,
445+ ExecutionError :: BadQuery ( _) ,
446+ "Expected a BadQuery error when including a non-counter statement in a COUNTER batch"
447+ ) ;
448+ }
0 commit comments