22
33use std:: iter;
44use std:: marker:: PhantomData ;
5+ use std:: sync:: atomic;
6+ use std:: sync:: atomic:: AtomicU8 ;
57use std:: sync:: Arc ;
68use std:: time:: Instant ;
79
@@ -10,7 +12,6 @@ use fail::fail_point;
1012use futures:: prelude:: * ;
1113use log:: debug;
1214use log:: warn;
13- use tokio:: sync:: RwLock ;
1415use tokio:: time:: Duration ;
1516
1617use crate :: backoff:: Backoff ;
@@ -76,8 +77,8 @@ use crate::Value;
7677/// txn.commit().await.unwrap();
7778/// # });
7879/// ```
79- pub struct Transaction < Cod : Codec = ApiV1TxnCodec , PdC : PdClient = PdRpcClient < Cod > > {
80- status : Arc < RwLock < TransactionStatus > > ,
80+ pub struct Transaction < Cod : Codec = ApiV1TxnCodec , PdC : PdClient < Codec = Cod > = PdRpcClient < Cod > > {
81+ status : Arc < AtomicU8 > ,
8182 timestamp : Timestamp ,
8283 buffer : Buffer ,
8384 rpc : Arc < PdC > ,
@@ -99,7 +100,7 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
99100 TransactionStatus :: Active
100101 } ;
101102 Transaction {
102- status : Arc :: new ( RwLock :: new ( status) ) ,
103+ status : Arc :: new ( AtomicU8 :: new ( status as u8 ) ) ,
103104 timestamp,
104105 buffer : Buffer :: new ( options. is_pessimistic ( ) ) ,
105106 rpc,
@@ -632,15 +633,16 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
632633 /// ```
633634 pub async fn commit ( & mut self ) -> Result < Option < Timestamp > > {
634635 debug ! ( "commiting transaction" ) ;
635- {
636- let mut status = self . status . write ( ) . await ;
637- if !matches ! (
638- * status,
639- TransactionStatus :: StartedCommit | TransactionStatus :: Active
640- ) {
641- return Err ( Error :: OperationAfterCommitError ) ;
642- }
643- * status = TransactionStatus :: StartedCommit ;
636+ if !self . transit_status (
637+ |status| {
638+ matches ! (
639+ status,
640+ TransactionStatus :: StartedCommit | TransactionStatus :: Active
641+ )
642+ } ,
643+ TransactionStatus :: StartedCommit ,
644+ ) {
645+ return Err ( Error :: OperationAfterCommitError ) ;
644646 }
645647
646648 let primary_key = self . buffer . get_primary_key ( ) ;
@@ -665,8 +667,7 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
665667 . await ;
666668
667669 if res. is_ok ( ) {
668- let mut status = self . status . write ( ) . await ;
669- * status = TransactionStatus :: Committed ;
670+ self . set_status ( TransactionStatus :: Committed ) ;
670671 }
671672 res
672673 }
@@ -689,21 +690,18 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
689690 /// ```
690691 pub async fn rollback ( & mut self ) -> Result < ( ) > {
691692 debug ! ( "rolling back transaction" ) ;
692- {
693- let status = self . status . read ( ) . await ;
694- if !matches ! (
695- * status,
696- TransactionStatus :: StartedRollback
697- | TransactionStatus :: Active
698- | TransactionStatus :: StartedCommit
699- ) {
700- return Err ( Error :: OperationAfterCommitError ) ;
701- }
702- }
703-
704- {
705- let mut status = self . status . write ( ) . await ;
706- * status = TransactionStatus :: StartedRollback ;
693+ if !self . transit_status (
694+ |status| {
695+ matches ! (
696+ status,
697+ TransactionStatus :: StartedRollback
698+ | TransactionStatus :: Active
699+ | TransactionStatus :: StartedCommit
700+ )
701+ } ,
702+ TransactionStatus :: StartedRollback ,
703+ ) {
704+ return Err ( Error :: OperationAfterCommitError ) ;
707705 }
708706
709707 let primary_key = self . buffer . get_primary_key ( ) ;
@@ -721,8 +719,7 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
721719 . await ;
722720
723721 if res. is_ok ( ) {
724- let mut status = self . status . write ( ) . await ;
725- * status = TransactionStatus :: Rolledback ;
722+ self . set_status ( TransactionStatus :: Rolledback ) ;
726723 }
727724 res
728725 }
@@ -906,8 +903,7 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
906903
907904 /// Checks if the transaction can perform arbitrary operations.
908905 async fn check_allow_operation ( & self ) -> Result < ( ) > {
909- let status = self . status . read ( ) . await ;
910- match * status {
906+ match self . get_status ( ) {
911907 TransactionStatus :: ReadOnly | TransactionStatus :: Active => Ok ( ( ) ) ,
912908 TransactionStatus :: Committed
913909 | TransactionStatus :: Rolledback
@@ -946,9 +942,9 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
946942 loop {
947943 tokio:: time:: sleep ( heartbeat_interval) . await ;
948944 {
949- let status = status. read ( ) . await ;
945+ let status: TransactionStatus = status. load ( atomic :: Ordering :: Acquire ) . into ( ) ;
950946 if matches ! (
951- * status,
947+ status,
952948 TransactionStatus :: Rolledback
953949 | TransactionStatus :: Committed
954950 | TransactionStatus :: Dropped
@@ -977,16 +973,45 @@ impl<Cod: Codec, PdC: PdClient<Codec = Cod>> Transaction<Cod, PdC> {
977973 }
978974 } ) ;
979975 }
976+
977+ fn get_status ( & self ) -> TransactionStatus {
978+ self . status . load ( atomic:: Ordering :: Acquire ) . into ( )
979+ }
980+
981+ fn set_status ( & self , status : TransactionStatus ) {
982+ self . status . store ( status as u8 , atomic:: Ordering :: Release ) ;
983+ }
984+
985+ fn transit_status < F > ( & self , check_status : F , next : TransactionStatus ) -> bool
986+ where
987+ F : Fn ( TransactionStatus ) -> bool ,
988+ {
989+ let mut current = self . get_status ( ) ;
990+ while check_status ( current) {
991+ if current == next {
992+ return true ;
993+ }
994+ match self . status . compare_exchange_weak (
995+ current as u8 ,
996+ next as u8 ,
997+ atomic:: Ordering :: AcqRel ,
998+ atomic:: Ordering :: Acquire ,
999+ ) {
1000+ Ok ( _) => return true ,
1001+ Err ( x) => current = x. into ( ) ,
1002+ }
1003+ }
1004+ false
1005+ }
9801006}
9811007
982- impl < Cod : Codec , PdC : PdClient > Drop for Transaction < Cod , PdC > {
1008+ impl < Cod : Codec , PdC : PdClient < Codec = Cod > > Drop for Transaction < Cod , PdC > {
9831009 fn drop ( & mut self ) {
9841010 debug ! ( "dropping transaction" ) ;
9851011 if std:: thread:: panicking ( ) {
9861012 return ;
9871013 }
988- let mut status = futures:: executor:: block_on ( self . status . write ( ) ) ;
989- if * status == TransactionStatus :: Active {
1014+ if self . get_status ( ) == TransactionStatus :: Active {
9901015 match self . options . check_level {
9911016 CheckLevel :: Panic => {
9921017 panic ! ( "Dropping an active transaction. Consider commit or rollback it." )
@@ -998,7 +1023,7 @@ impl<Cod: Codec, PdC: PdClient> Drop for Transaction<Cod, PdC> {
9981023 CheckLevel :: None => { }
9991024 }
10001025 }
1001- * status = TransactionStatus :: Dropped ;
1026+ self . set_status ( TransactionStatus :: Dropped ) ;
10021027 }
10031028}
10041029
@@ -1432,22 +1457,38 @@ impl<PdC: PdClient> Committer<PdC> {
14321457 }
14331458}
14341459
1435- #[ derive( PartialEq , Eq ) ]
1460+ #[ derive( PartialEq , Eq , Clone , Copy ) ]
1461+ #[ repr( u8 ) ]
14361462enum TransactionStatus {
14371463 /// The transaction is read-only [`Snapshot`](super::Snapshot), no need to commit or rollback or panic on drop.
1438- ReadOnly ,
1464+ ReadOnly = 0 ,
14391465 /// The transaction have not been committed or rolled back.
1440- Active ,
1466+ Active = 1 ,
14411467 /// The transaction has committed.
1442- Committed ,
1468+ Committed = 2 ,
14431469 /// The transaction has tried to commit. Only `commit` is allowed.
1444- StartedCommit ,
1470+ StartedCommit = 3 ,
14451471 /// The transaction has rolled back.
1446- Rolledback ,
1472+ Rolledback = 4 ,
14471473 /// The transaction has tried to rollback. Only `rollback` is allowed.
1448- StartedRollback ,
1474+ StartedRollback = 5 ,
14491475 /// The transaction has been dropped.
1450- Dropped ,
1476+ Dropped = 6 ,
1477+ }
1478+
1479+ impl From < u8 > for TransactionStatus {
1480+ fn from ( num : u8 ) -> Self {
1481+ match num {
1482+ 0 => TransactionStatus :: ReadOnly ,
1483+ 1 => TransactionStatus :: Active ,
1484+ 2 => TransactionStatus :: Committed ,
1485+ 3 => TransactionStatus :: StartedCommit ,
1486+ 4 => TransactionStatus :: Rolledback ,
1487+ 5 => TransactionStatus :: StartedRollback ,
1488+ 6 => TransactionStatus :: Dropped ,
1489+ _ => panic ! ( "Unknown transaction status {}" , num) ,
1490+ }
1491+ }
14511492}
14521493
14531494#[ cfg( test) ]
0 commit comments