From 261beac2fc7805943625f553e030bc3893d6d7fe Mon Sep 17 00:00:00 2001
From: Anna Hayman
Date: Thu, 13 Nov 2025 14:06:25 -0500
Subject: [PATCH 1/2] Add support for configurable token bucket success reward
 and fractional token management

Adding support for atomic float
---
 .changelog/1763060740.md                      |  11 +
 .../src/client/retries/strategy/standard.rs   |   5 +-
 .../src/client/retries/token_bucket.rs        | 380 +++++++++++++++++-
 3 files changed, 386 insertions(+), 10 deletions(-)
 create mode 100644 .changelog/1763060740.md

diff --git a/.changelog/1763060740.md b/.changelog/1763060740.md
new file mode 100644
index 0000000000..832bf0e5c4
--- /dev/null
+++ b/.changelog/1763060740.md
@@ -0,0 +1,11 @@
+---
+applies_to:
+- client
+authors:
+- annahay
+references: []
+breaking: false
+new_feature: true
+bug_fix: false
+---
+Add support for configurable token bucket success reward and fractional token management

diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs
index c8700ee275..f2508c3eeb 100644
--- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs
+++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs
@@ -210,8 +210,11 @@ impl RetryStrategy for StandardRetryStrategy {
             .unwrap_or(false);
         update_rate_limiter_if_exists(runtime_components, cfg, is_throttling_error);

-        // on success release any retry quota held by previous attempts
+        // on success release any retry quota held by previous attempts and award success tokens
         if !ctx.is_failed() {
+            // When a request succeeds, credit the configured success reward, if any
+            token_bucket.reward_success();
+
             if let NoPermitWasReleased = self.release_retry_permit() {
                 // In the event that there was no retry permit to release, we generate new
                 // permits from nothing. We do this to make up for permits we had to "forget".

diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs
index c64c9b9e04..1beaffedda 100644
--- a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs
+++ b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs
@@ -5,14 +5,18 @@

 use aws_smithy_types::config_bag::{Storable, StoreReplace};
 use aws_smithy_types::retry::ErrorKind;
+use std::fmt;
+use std::sync::atomic::AtomicU64;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 use tracing::trace;

 const DEFAULT_CAPACITY: usize = 500;
-const RETRY_COST: u32 = 5;
-const RETRY_TIMEOUT_COST: u32 = RETRY_COST * 2;
+const DEFAULT_RETRY_COST: u32 = 5;
+const DEFAULT_RETRY_TIMEOUT_COST: u32 = DEFAULT_RETRY_COST * 2;
 const PERMIT_REGENERATION_AMOUNT: usize = 1;
+const DEFAULT_SUCCESS_REWARD: f64 = 0.0;

 /// Token bucket used for standard and adaptive retry.
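+///
+/// A minimal construction sketch (illustrative only; it exercises the builder
+/// options introduced in this patch):
+///
+/// ```no_run
+/// use aws_smithy_runtime::client::retries::TokenBucket;
+///
+/// let bucket = TokenBucket::builder()
+///     .capacity(500)       // maximum number of permits in the bucket
+///     .success_reward(0.5) // fractional tokens credited per successful request
+///     .build();
+/// ```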
#[derive(Clone, Debug)]
@@ -21,6 +25,44 @@ pub struct TokenBucket {
     max_permits: usize,
     timeout_retry_cost: u32,
     retry_cost: u32,
+    success_reward: f64,
+    fractional_tokens: AtomicF64,
+}
+
+pub struct AtomicF64 {
+    storage: AtomicU64,
+}
+impl AtomicF64 {
+    pub fn new(value: f64) -> Self {
+        let as_u64 = value.to_bits();
+        Self { storage: AtomicU64::new(as_u64) }
+    }
+    pub fn store(&self, value: f64) {
+        let as_u64 = value.to_bits();
+        self.storage.store(as_u64, Ordering::Relaxed)
+    }
+    pub fn load(&self) -> f64 {
+        let as_u64 = self.storage.load(Ordering::Relaxed);
+        f64::from_bits(as_u64)
+    }
+}
+
+impl fmt::Debug for AtomicF64 {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Format the decoded float value rather than the raw bit storage
+        f.debug_struct("AtomicF64")
+            .field("value", &self.load())
+            .finish()
+    }
+}
+
+impl Clone for AtomicF64 {
+    fn clone(&self) -> Self {
+        // Snapshot the current bit pattern into a new atomic
+        AtomicF64 {
+            storage: AtomicU64::new(self.storage.load(Ordering::Relaxed)),
+        }
+    }
 }

 impl Storable for TokenBucket {
@@ -32,8 +74,10 @@ impl Default for TokenBucket {
         Self {
             semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
             max_permits: DEFAULT_CAPACITY,
-            timeout_retry_cost: RETRY_TIMEOUT_COST,
-            retry_cost: RETRY_COST,
+            timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
+            retry_cost: DEFAULT_RETRY_COST,
+            success_reward: DEFAULT_SUCCESS_REWARD,
+            fractional_tokens: AtomicF64::new(0.0),
         }
     }
 }
@@ -55,6 +99,8 @@ impl TokenBucket {
             max_permits: Semaphore::MAX_PERMITS,
             timeout_retry_cost: 0,
             retry_cost: 0,
+            success_reward: 0.0,
+            fractional_tokens: AtomicF64::new(0.0),
         }
     }

@@ -77,12 +123,39 @@ impl TokenBucket {
     }

     pub(crate) fn regenerate_a_token(&self) {
-        if self.semaphore.available_permits() < self.max_permits {
-            trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
-            self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
+        self.add_tokens(PERMIT_REGENERATION_AMOUNT);
+    }
+
+    pub(crate) fn reward_success(&self) {
+        // Verify that fractional tokens have not become corrupted
+        if !self.fractional_tokens.load().is_finite() {
+            tracing::error!("Fractional tokens corrupted to: {}", self.fractional_tokens.load());
+            // If corrupted, reset to the number of permits the bucket was created with
+            self.fractional_tokens.store(self.max_permits as f64);
+            return;
+        }
+
+        if self.success_reward > 0.0 {
+            self.fractional_tokens.store(self.fractional_tokens.load() + self.success_reward);
+        }
+
+        let full_tokens_accumulated = self.fractional_tokens.load().floor();
+        if full_tokens_accumulated >= 1.0 {
+            self.add_tokens(full_tokens_accumulated as usize);
+            self.fractional_tokens.store(self.fractional_tokens.load() - full_tokens_accumulated);
         }
     }

+    fn add_tokens(&self, amount: usize) {
+        let available = self.semaphore.available_permits();
+        if available >= self.max_permits {
+            return;
+        }
+        let tokens_to_add = amount.min(self.max_permits - available);
+        trace!("adding {tokens_to_add} back into the bucket");
+        self.semaphore.add_permits(tokens_to_add);
+    }
+
     #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
     pub(crate) fn available_permits(&self) -> usize {
         self.semaphore.available_permits()
@@ -95,6 +168,7 @@ pub struct TokenBucketBuilder {
     capacity: Option<usize>,
     retry_cost: Option<u32>,
     timeout_retry_cost: Option<u32>,
+    success_reward: Option<f64>,
 }

 impl TokenBucketBuilder {
@@ -121,13 +195,24 @@ impl TokenBucketBuilder {
         self
     }

+    /// Sets the reward for any successful request for the builder.
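+    ///
+    /// Rewards below 1.0 accumulate fractionally before being converted into
+    /// whole permits. Illustrative arithmetic: with a reward of 0.4, three
+    /// successful requests accumulate 1.2 tokens, one whole permit is returned
+    /// to the bucket, and 0.2 carries over.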
+ pub fn success_reward(mut self, reward: f64) -> Self { + self.success_reward = Some(reward); + self + } + /// Builds a `TokenBucket`. pub fn build(self) -> TokenBucket { TokenBucket { semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))), max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY), - retry_cost: self.retry_cost.unwrap_or(RETRY_COST), - timeout_retry_cost: self.timeout_retry_cost.unwrap_or(RETRY_TIMEOUT_COST), + retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST), + timeout_retry_cost: self + .timeout_retry_cost + .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST), + success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD), + // fractional_tokens: Arc::new(Mutex::new(0.0)), + fractional_tokens: AtomicF64::new(0.0), } } } @@ -184,4 +269,281 @@ mod tests { // Verify next acquisition fails assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none()); } + + #[test] + fn test_fractional_tokens_accumulate_and_convert() { + let bucket = TokenBucket::builder() + .capacity(10) + .success_reward(0.4) + .build(); + + // acquire 10 tokens to bring capacity below max so we can test accumulation + let _hold_permit = bucket.acquire(&ErrorKind::TransientError); + assert_eq!(bucket.semaphore.available_permits(), 0); + + // First success: 0.4 fractional tokens + bucket.reward_success(); + assert_eq!(bucket.semaphore.available_permits(), 0); + + // Second success: 0.8 fractional tokens + bucket.reward_success(); + assert_eq!(bucket.semaphore.available_permits(), 0); + + // Third success: 1.2 fractional tokens -> 1 full token added + bucket.reward_success(); + assert_eq!(bucket.semaphore.available_permits(), 1); + } + + #[test] + fn test_fractional_tokens_respect_max_capacity() { + let bucket = TokenBucket::builder() + .capacity(10) + .success_reward(2.0) + .build(); + + for _ in 0..20 { + bucket.reward_success(); + } + + assert!(bucket.semaphore.available_permits() == 10); + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_builder_with_custom_values() { + let bucket = TokenBucket::builder() + .capacity(100) + .retry_cost(10) + .timeout_retry_cost(20) + .success_reward(0.5) + .build(); + + assert_eq!(bucket.max_permits, 100); + assert_eq!(bucket.retry_cost, 10); + assert_eq!(bucket.timeout_retry_cost, 20); + assert_eq!(bucket.success_reward, 0.5); + } + + #[test] + fn test_atomicf64_f64_to_bits_conversion_correctness() { + // This is the core functionality + let test_values = vec![ + 0.0, -0.0, 1.0, -1.0, + f64::INFINITY, f64::NEG_INFINITY, f64::NAN, + f64::MIN, f64::MAX, f64::MIN_POSITIVE, f64::EPSILON, + std::f64::consts::PI, std::f64::consts::E, + // Test values that could expose bit manipulation bugs + 1.23456789e-308, // Very small normal number + 1.23456789e308, // Very large number + 2.2250738585072014e-308, // Near MIN_POSITIVE + ]; + + for &expected in &test_values { + let atomic = AtomicF64::new(expected); + let actual = atomic.load(); + + // For NaN, we can't use == but must check bit patterns + if expected.is_nan() { + assert!(actual.is_nan(), "Expected NaN, got {}", actual); + // Different NaN bit patterns should be preserved exactly + assert_eq!(expected.to_bits(), actual.to_bits()); + } else { + assert_eq!(expected.to_bits(), actual.to_bits()); + } + } + } + + #[test] + fn test_atomicf64_store_load_preserves_exact_bits() { + let atomic = AtomicF64::new(0.0); + + // Test that store/load cycle preserves EXACT bit patterns + // This would catch bugs in the to_bits/from_bits conversion + let 
critical_bit_patterns = vec![ + 0x0000000000000000u64, // +0.0 + 0x8000000000000000u64, // -0.0 + 0x7FF0000000000000u64, // +infinity + 0xFFF0000000000000u64, // -infinity + 0x7FF8000000000000u64, // Quiet NaN + 0x7FF4000000000000u64, // Signaling NaN + 0x0000000000000001u64, // Smallest positive subnormal + 0x000FFFFFFFFFFFFFu64, // Largest subnormal + 0x0010000000000000u64, // Smallest positive normal (MIN_POSITIVE) + ]; + + for &expected_bits in &critical_bit_patterns { + let expected_f64 = f64::from_bits(expected_bits); + atomic.store(expected_f64); + let loaded_f64 = atomic.load(); + let actual_bits = loaded_f64.to_bits(); + + assert_eq!(expected_bits, actual_bits); + } + } + + #[test] + fn test_atomicf64_concurrent_store_load_safety() { + use std::sync::Arc; + use std::thread; + + let atomic = Arc::new(AtomicF64::new(0.0)); + let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let mut handles = Vec::new(); + + // Start multiple threads that continuously write different values + for &value in &test_values { + let atomic_clone = Arc::clone(&atomic); + let handle = thread::spawn(move || { + for _ in 0..1000 { + atomic_clone.store(value); + } + }); + handles.push(handle); + } + + // Start a reader thread that continuously reads + let atomic_reader = Arc::clone(&atomic); + let reader_handle = thread::spawn(move || { + let mut readings = Vec::new(); + for _ in 0..5000 { + let value = atomic_reader.load(); + readings.push(value); + } + readings + }); + + // Wait for all writers to complete + for handle in handles { + handle.join().expect("Writer thread panicked"); + } + + let readings = reader_handle.join().expect("Reader thread panicked"); + + // Verify that all read values are valid (one of the written values) + // This tests that there's no data corruption from concurrent access + for &reading in &readings { + assert!(test_values.contains(&reading) || reading == 0.0); + + // More importantly, verify the reading is a valid f64 + // (not corrupted bits that happen to parse as valid) + assert!(reading.is_finite() || reading == 0.0, + "Corrupted reading detected: {}"); + } + } + + #[test] + fn test_atomicf64_stress_concurrent_access() { + use std::sync::{Arc, Barrier}; + use std::thread; + + let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let atomic = Arc::new(AtomicF64::new(0.0)); + let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads + let mut handles = Vec::new(); + + // Launch threads that all start simultaneously + for i in 0..10 { + let atomic_clone = Arc::clone(&atomic); + let barrier_clone = Arc::clone(&barrier); + let handle = thread::spawn(move || { + barrier_clone.wait(); // All threads start at same time + + // Tight loop increases chance of race conditions + for _ in 0..10000 { + let value = i as f64; + atomic_clone.store(value); + let loaded = atomic_clone.load(); + // Verify no corruption occurred + assert!(loaded >= 0.0 && loaded <= 9.0); + assert!(expected_values.contains(&loaded), + "Got unexpected value: {}, expected one of {:?}", loaded, expected_values); + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + } + + + #[test] + fn test_atomicf64_relaxed_ordering_semantics() { + use std::sync::Arc; + use std::thread; + use std::sync::atomic::{AtomicBool, Ordering}; + + // Test that Relaxed ordering doesn't cause obvious problems + // (This is hard to test definitively, but we can check basic operation) + let atomic = Arc::new(AtomicF64::new(1.0)); + let flag = 
Arc::new(AtomicBool::new(false));
+
+        let atomic_clone = Arc::clone(&atomic);
+        let flag_clone = Arc::clone(&flag);
+
+        let writer = thread::spawn(move || {
+            atomic_clone.store(42.0);
+            flag_clone.store(true, Ordering::Release);
+        });
+
+        let atomic_reader = Arc::clone(&atomic);
+        let flag_reader = Arc::clone(&flag);
+
+        let reader = thread::spawn(move || {
+            // Spin until flag is set
+            while !flag_reader.load(Ordering::Acquire) {
+                std::hint::spin_loop();
+            }
+            atomic_reader.load()
+        });
+
+        writer.join().expect("Writer panicked");
+        let final_value = reader.join().expect("Reader panicked");
+
+        // Due to relaxed ordering on the AtomicF64, we might see the old or new value
+        assert!(final_value == 1.0 || final_value == 42.0,
+            "Unexpected value: {}", final_value);
+    }
+
+    #[test]
+    fn test_atomicf64_integration_with_token_bucket_usage() {
+        let atomic = AtomicF64::new(0.0);
+        let success_reward = 0.3;
+        let iterations = 5;
+
+        // Accumulate fractional tokens
+        for i in 1..=iterations {
+            let current = atomic.load();
+            atomic.store(current + success_reward);
+        }
+
+        let accumulated = atomic.load();
+        let expected_total = iterations as f64 * success_reward; // 1.5
+
+        // Test the floor() operation pattern
+        let full_tokens = accumulated.floor();
+        atomic.store(accumulated - full_tokens);
+        let remaining = atomic.load();
+
+        // These assertions should be general:
+        assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc.
+        assert!(remaining >= 0.0 && remaining < 1.0);
+        assert_eq!(remaining, expected_total - expected_total.floor());
+    }
+
+
+    #[test]
+    fn test_atomicf64_clone_creates_independent_copy() {
+        let original = AtomicF64::new(123.456);
+        let cloned = original.clone();
+
+        // Verify they start with the same value
+        assert_eq!(original.load(), cloned.load());
+
+        // Verify they're independent - modifying one doesn't affect the other
+        original.store(999.0);
+        assert_eq!(cloned.load(), 123.456, "Clone should be unaffected by original changes");
+        assert_eq!(original.load(), 999.0, "Original should have new value");
+    }
 }

From d2c1775bbe691659b33750ec563ce4ab885c2dc3 Mon Sep 17 00:00:00 2001
From: Anna Hayman
Date: Fri, 21 Nov 2025 10:29:01 -0500
Subject: [PATCH 2/2] Add additional token bucket functionality, including a
 maximum capacity to accommodate multiple architectures

---
 .github/workflows/manual-pull-request-bot.yml |   2 +-
 .../smithy/rustsdk/RetryPartitionTest.kt      |   4 +-
 .../src/client/retries/token_bucket.rs        | 293 +++++++++---------
 3 files changed, 153 insertions(+), 146 deletions(-)

diff --git a/.github/workflows/manual-pull-request-bot.yml b/.github/workflows/manual-pull-request-bot.yml
index e16838ce09..e71ed6602d 100644
--- a/.github/workflows/manual-pull-request-bot.yml
+++ b/.github/workflows/manual-pull-request-bot.yml
@@ -56,7 +56,7 @@ jobs:
       contents: read
     needs:
     - get-pr-info
-    runs-on: ubuntu-latest
+    runs-on: smithy_ubuntu-latest_8-core
     steps:
     - uses: GitHubSecurityLab/actions-permissions/monitor@v1
     - uses: actions/checkout@v4

diff --git a/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt b/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt
index 3a3c022ede..db34547e3e 100644
--- a/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt
+++ b/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt
@@ -114,6 +114,7 @@ class RetryPartitionTest {
                "RetryPartition" to 
RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::RetryPartition"),
                "RuntimeComponents" to RuntimeType.runtimeComponents(ctx.runtimeConfig),
                "TokenBucket" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::TokenBucket"),
+               "MAXIMUM_CAPACITY" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::token_bucket::MAXIMUM_CAPACITY"),
            )
        crate.integrationTest("custom_retry_partition") {
            tokioTest("test_custom_token_bucket") {
@@ -139,7 +140,8 @@ class RetryPartitionTest {
                        ) -> Result<(), #{BoxError}> {
                            self.called.fetch_add(1, Ordering::Relaxed);
                            let token_bucket = cfg.load::<#{TokenBucket}>().unwrap();
-                           let expected = format!("permits: {}", tokio::sync::Semaphore::MAX_PERMITS);
+                           let max_capacity = #{MAXIMUM_CAPACITY};
+                           let expected = format!("permits: {}", max_capacity);
                            assert!(
                                format!("{token_bucket:?}").contains(&expected),
                                "Expected debug output to contain `{expected}`, but got: {token_bucket:?}"

diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs
index 1beaffedda..c8aa9b5dfc 100644
--- a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs
+++ b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs
@@ -6,17 +6,23 @@
 use aws_smithy_types::config_bag::{Storable, StoreReplace};
 use aws_smithy_types::retry::ErrorKind;
 use std::fmt;
-use std::sync::atomic::AtomicU64;
+use std::sync::atomic::AtomicU32;
 use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 use tracing::trace;

 const DEFAULT_CAPACITY: usize = 500;
+// On a 32-bit architecture, the value of Semaphore::MAX_PERMITS is 536,870,911.
+// Therefore, we enforce a value lower than that to ensure behavior is
+// identical across platforms.
+// This also allows room for slight bucket overfill in the case where a bucket
+// is at maximum capacity and another thread drops a permit it was holding.
+pub const MAXIMUM_CAPACITY: usize = 500_000_000;
 const DEFAULT_RETRY_COST: u32 = 5;
 const DEFAULT_RETRY_TIMEOUT_COST: u32 = DEFAULT_RETRY_COST * 2;
 const PERMIT_REGENERATION_AMOUNT: usize = 1;
-const DEFAULT_SUCCESS_REWARD: f64 = 0.0;
+const DEFAULT_SUCCESS_REWARD: f32 = 0.0;

 /// Token bucket used for standard and adaptive retry.
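+///
+/// Illustrative sketch of the capacity clamp (mirrors the check in
+/// `TokenBucketBuilder::capacity` below):
+///
+/// ```no_run
+/// use aws_smithy_runtime::client::retries::TokenBucket;
+///
+/// // Requests above MAXIMUM_CAPACITY (500_000_000) are clamped so the bucket
+/// // behaves identically on 32- and 64-bit targets.
+/// let bucket = TokenBucket::builder().capacity(600_000_000).build();
+/// ```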
#[derive(Clone, Debug)]
@@ -25,42 +31,44 @@ pub struct TokenBucket {
     max_permits: usize,
     timeout_retry_cost: u32,
     retry_cost: u32,
-    success_reward: f64,
-    fractional_tokens: AtomicF64,
+    success_reward: f32,
+    fractional_tokens: AtomicF32,
 }

-pub struct AtomicF64 {
-    storage: AtomicU64,
+struct AtomicF32 {
+    storage: AtomicU32,
 }
-impl AtomicF64 {
-    pub fn new(value: f64) -> Self {
-        let as_u64 = value.to_bits();
-        Self { storage: AtomicU64::new(as_u64) }
+impl AtomicF32 {
+    fn new(value: f32) -> Self {
+        let as_u32 = value.to_bits();
+        Self {
+            storage: AtomicU32::new(as_u32),
+        }
     }
-    pub fn store(&self, value: f64) {
-        let as_u64 = value.to_bits();
-        self.storage.store(as_u64, Ordering::Relaxed)
+    fn store(&self, value: f32) {
+        let as_u32 = value.to_bits();
+        self.storage.store(as_u32, Ordering::Relaxed)
     }
-    pub fn load(&self) -> f64 {
-        let as_u64 = self.storage.load(Ordering::Relaxed);
-        f64::from_bits(as_u64)
+    fn load(&self) -> f32 {
+        let as_u32 = self.storage.load(Ordering::Relaxed);
+        f32::from_bits(as_u32)
     }
 }

-impl fmt::Debug for AtomicF64 {
+impl fmt::Debug for AtomicF32 {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         // Format the decoded float value rather than the raw bit storage
-        f.debug_struct("AtomicF64")
+        f.debug_struct("AtomicF32")
             .field("value", &self.load())
             .finish()
     }
 }

-impl Clone for AtomicF64 {
+impl Clone for AtomicF32 {
     fn clone(&self) -> Self {
         // Snapshot the current bit pattern into a new atomic
-        AtomicF64 {
-            storage: AtomicU64::new(self.storage.load(Ordering::Relaxed)),
+        AtomicF32 {
+            storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)),
         }
     }
 }
@@ -77,7 +85,7 @@ impl Default for TokenBucket {
             timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
             retry_cost: DEFAULT_RETRY_COST,
             success_reward: DEFAULT_SUCCESS_REWARD,
-            fractional_tokens: AtomicF64::new(0.0),
+            fractional_tokens: AtomicF32::new(0.0),
         }
     }
 }
@@ -95,12 +103,12 @@ impl TokenBucket {
     /// A token bucket with unlimited capacity that allows retries at no cost.
     pub fn unlimited() -> Self {
         Self {
-            semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
-            max_permits: Semaphore::MAX_PERMITS,
+            semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
+            max_permits: MAXIMUM_CAPACITY,
             timeout_retry_cost: 0,
             retry_cost: 0,
             success_reward: 0.0,
-            fractional_tokens: AtomicF64::new(0.0),
+            fractional_tokens: AtomicF32::new(0.0),
         }
     }

@@ -110,6 +118,16 @@ impl TokenBucket {
     }

     pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
+        // We have to handle the case where the number of permits in the semaphore exceeds the intended
+        // max. This can occur when the bucket is already at max capacity (success reward > 0) and then an
+        // OwnedSemaphorePermit gets dropped (destroyed), automatically returning its permits to the
+        // semaphore and causing it to exceed max_permits.
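+        // Illustrative numbers (not part of this patch): with max_permits = 500
+        // and the bucket already refilled to 500 by success rewards, dropping a
+        // held permit of cost 5 would leave 505 available; the excess 5 permits
+        // are forgotten here before this attempt's retry cost is charged.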
+        let available_permits = self.semaphore.available_permits();
+        if available_permits > self.max_permits {
+            self.semaphore
+                .forget_permits(available_permits - self.max_permits);
+        }
+
         let retry_cost = if err == &ErrorKind::TransientError {
             self.timeout_retry_cost
         } else {
@@ -127,23 +145,25 @@
     }

     pub(crate) fn reward_success(&self) {
-        // Verify that fractional tokens have not become corrupted
-        if !self.fractional_tokens.load().is_finite() {
-            tracing::error!("Fractional tokens corrupted to: {}", self.fractional_tokens.load());
-            // If corrupted, reset to the number of permits the bucket was created with
-            self.fractional_tokens.store(self.max_permits as f64);
+        let mut calc_fractional_tokens = self.fractional_tokens.load();
+        // Verify that fractional tokens have not become corrupted - if they have, reset to zero
+        if !calc_fractional_tokens.is_finite() {
+            tracing::error!("Fractional tokens corrupted to: {}, resetting to 0.0", calc_fractional_tokens);
+            self.fractional_tokens.store(0.0);
             return;
         }

         if self.success_reward > 0.0 {
-            self.fractional_tokens.store(self.fractional_tokens.load() + self.success_reward);
+            calc_fractional_tokens += self.success_reward;
         }

-        let full_tokens_accumulated = self.fractional_tokens.load().floor();
+        let full_tokens_accumulated = calc_fractional_tokens.floor();
         if full_tokens_accumulated >= 1.0 {
             self.add_tokens(full_tokens_accumulated as usize);
-            self.fractional_tokens.store(self.fractional_tokens.load() - full_tokens_accumulated);
+            calc_fractional_tokens -= full_tokens_accumulated;
         }
+        // Always store the updated fractional tokens back, even if no conversion happened
+        self.fractional_tokens.store(calc_fractional_tokens);
     }

     fn add_tokens(&self, amount: usize) {
@@ -168,7 +188,7 @@ pub struct TokenBucketBuilder {
     capacity: Option<usize>,
     retry_cost: Option<u32>,
     timeout_retry_cost: Option<u32>,
-    success_reward: Option<f64>,
+    success_reward: Option<f32>,
 }

 impl TokenBucketBuilder {
@@ -178,7 +198,10 @@ impl TokenBucketBuilder {
     }

     /// Sets the maximum bucket capacity for the builder.
-    pub fn capacity(mut self, capacity: usize) -> Self {
+    pub fn capacity(mut self, mut capacity: usize) -> Self {
+        if capacity > MAXIMUM_CAPACITY {
+            capacity = MAXIMUM_CAPACITY;
+        }
         self.capacity = Some(capacity);
         self
     }
@@ -196,7 +219,7 @@
     }

     /// Sets the reward for any successful request for the builder.
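+    ///
+    /// A reward of 0.0 (the default) disables success rewards entirely;
+    /// fractional amounts accumulate across successes and are converted to
+    /// whole permits once they reach 1.0 (see `reward_success` above).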
- pub fn success_reward(mut self, reward: f64) -> Self { + pub fn success_reward(mut self, reward: f32) -> Self { self.success_reward = Some(reward); self } @@ -211,8 +234,7 @@ impl TokenBucketBuilder { .timeout_retry_cost .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST), success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD), - // fractional_tokens: Arc::new(Mutex::new(0.0)), - fractional_tokens: AtomicF64::new(0.0), + fractional_tokens: AtomicF32::new(0.0), } } } @@ -230,7 +252,7 @@ mod tests { assert!(bucket.acquire(&ErrorKind::TransientError).is_some()); // Should have maximum capacity - assert_eq!(bucket.max_permits, Semaphore::MAX_PERMITS); + assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY); // Should have zero retry costs assert_eq!(bucket.retry_cost, 0); @@ -244,7 +266,7 @@ mod tests { permits.push(permit); // Available permits should stay constant assert_eq!( - tokio::sync::Semaphore::MAX_PERMITS, + MAXIMUM_CAPACITY, bucket.semaphore.available_permits() ); } @@ -325,23 +347,32 @@ mod tests { } #[test] - fn test_atomicf64_f64_to_bits_conversion_correctness() { + fn test_atomicf32_f32_to_bits_conversion_correctness() { // This is the core functionality let test_values = vec![ - 0.0, -0.0, 1.0, -1.0, - f64::INFINITY, f64::NEG_INFINITY, f64::NAN, - f64::MIN, f64::MAX, f64::MIN_POSITIVE, f64::EPSILON, - std::f64::consts::PI, std::f64::consts::E, + 0.0, + -0.0, + 1.0, + -1.0, + f32::INFINITY, + f32::NEG_INFINITY, + f32::NAN, + f32::MIN, + f32::MAX, + f32::MIN_POSITIVE, + f32::EPSILON, + std::f32::consts::PI, + std::f32::consts::E, // Test values that could expose bit manipulation bugs - 1.23456789e-308, // Very small normal number - 1.23456789e308, // Very large number - 2.2250738585072014e-308, // Near MIN_POSITIVE + 1.23456789e-38, // Very small normal number + 1.23456789e38, // Very large number (within f32 range) + 1.1754944e-38, // Near MIN_POSITIVE for f32 ]; - + for &expected in &test_values { - let atomic = AtomicF64::new(expected); + let atomic = AtomicF32::new(expected); let actual = atomic.load(); - + // For NaN, we can't use == but must check bit patterns if expected.is_nan() { assert!(actual.is_nan(), "Expected NaN, got {}", actual); @@ -353,43 +384,45 @@ mod tests { } } + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] #[test] - fn test_atomicf64_store_load_preserves_exact_bits() { - let atomic = AtomicF64::new(0.0); - + fn test_atomicf32_store_load_preserves_exact_bits() { + let atomic = AtomicF32::new(0.0); + // Test that store/load cycle preserves EXACT bit patterns // This would catch bugs in the to_bits/from_bits conversion let critical_bit_patterns = vec![ - 0x0000000000000000u64, // +0.0 - 0x8000000000000000u64, // -0.0 - 0x7FF0000000000000u64, // +infinity - 0xFFF0000000000000u64, // -infinity - 0x7FF8000000000000u64, // Quiet NaN - 0x7FF4000000000000u64, // Signaling NaN - 0x0000000000000001u64, // Smallest positive subnormal - 0x000FFFFFFFFFFFFFu64, // Largest subnormal - 0x0010000000000000u64, // Smallest positive normal (MIN_POSITIVE) + 0x00000000u32, // +0.0 + 0x80000000u32, // -0.0 + 0x7F800000u32, // +infinity + 0xFF800000u32, // -infinity + 0x7FC00000u32, // Quiet NaN + 0x7FA00000u32, // Signaling NaN + 0x00000001u32, // Smallest positive subnormal + 0x007FFFFFu32, // Largest subnormal + 0x00800000u32, // Smallest positive normal (MIN_POSITIVE) ]; - + for &expected_bits in &critical_bit_patterns { - let expected_f64 = f64::from_bits(expected_bits); - atomic.store(expected_f64); - let loaded_f64 = atomic.load(); - let actual_bits = 
loaded_f64.to_bits(); - + let expected_f32 = f32::from_bits(expected_bits); + atomic.store(expected_f32); + let loaded_f32 = atomic.load(); + let actual_bits = loaded_f32.to_bits(); + assert_eq!(expected_bits, actual_bits); } } - #[test] - fn test_atomicf64_concurrent_store_load_safety() { + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_atomicf32_concurrent_store_load_safety() { use std::sync::Arc; use std::thread; - - let atomic = Arc::new(AtomicF64::new(0.0)); + + let atomic = Arc::new(AtomicF32::new(0.0)); let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let mut handles = Vec::new(); - + // Start multiple threads that continuously write different values for &value in &test_values { let atomic_clone = Arc::clone(&atomic); @@ -400,7 +433,7 @@ mod tests { }); handles.push(handle); } - + // Start a reader thread that continuously reads let atomic_reader = Arc::clone(&atomic); let reader_handle = thread::spawn(move || { @@ -411,139 +444,111 @@ mod tests { } readings }); - + // Wait for all writers to complete for handle in handles { handle.join().expect("Writer thread panicked"); } - + let readings = reader_handle.join().expect("Reader thread panicked"); - + // Verify that all read values are valid (one of the written values) // This tests that there's no data corruption from concurrent access for &reading in &readings { assert!(test_values.contains(&reading) || reading == 0.0); - - // More importantly, verify the reading is a valid f64 + + // More importantly, verify the reading is a valid f32 // (not corrupted bits that happen to parse as valid) - assert!(reading.is_finite() || reading == 0.0, - "Corrupted reading detected: {}"); + assert!( + reading.is_finite() || reading == 0.0, + "Corrupted reading detected" + ); } } + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] #[test] - fn test_atomicf64_stress_concurrent_access() { + fn test_atomicf32_stress_concurrent_access() { use std::sync::{Arc, Barrier}; use std::thread; - + let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; - let atomic = Arc::new(AtomicF64::new(0.0)); + let atomic = Arc::new(AtomicF32::new(0.0)); let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads let mut handles = Vec::new(); - + // Launch threads that all start simultaneously for i in 0..10 { let atomic_clone = Arc::clone(&atomic); let barrier_clone = Arc::clone(&barrier); let handle = thread::spawn(move || { barrier_clone.wait(); // All threads start at same time - + // Tight loop increases chance of race conditions for _ in 0..10000 { - let value = i as f64; + let value = i as f32; atomic_clone.store(value); let loaded = atomic_clone.load(); // Verify no corruption occurred assert!(loaded >= 0.0 && loaded <= 9.0); - assert!(expected_values.contains(&loaded), - "Got unexpected value: {}, expected one of {:?}", loaded, expected_values); + assert!( + expected_values.contains(&loaded), + "Got unexpected value: {}, expected one of {:?}", + loaded, + expected_values + ); } }); handles.push(handle); } - + for handle in handles { handle.join().unwrap(); } } - - #[test] - fn test_atomicf64_relaxed_ordering_semantics() { - use std::sync::Arc; - use std::thread; - use std::sync::atomic::{AtomicBool, Ordering}; - - // Test that Relaxed ordering doesn't cause obvious problems - // (This is hard to test definitively, but we can check basic operation) - let atomic = Arc::new(AtomicF64::new(1.0)); - let flag = Arc::new(AtomicBool::new(false)); - - let atomic_clone = Arc::clone(&atomic); - 
let flag_clone = Arc::clone(&flag); - - let writer = thread::spawn(move || { - atomic_clone.store(42.0); - flag_clone.store(true, Ordering::Release); - }); - - let atomic_reader = Arc::clone(&atomic); - let flag_reader = Arc::clone(&flag); - - let reader = thread::spawn(move || { - // Spin until flag is set - while !flag_reader.load(Ordering::Acquire) { - std::hint::spin_loop(); - } - atomic_reader.load() - }); - - writer.join().expect("Writer panicked"); - let final_value = reader.join().expect("Reader panicked"); - - // Due to relaxed ordering on the AtomicF64, we might see the old or new value - assert!(final_value == 1.0 || final_value == 42.0, - "Unexpected value: {}", final_value); - } - #[test] - fn test_atomicf64_integration_with_token_bucket_usage() { - let atomic = AtomicF64::new(0.0); + fn test_atomicf32_integration_with_token_bucket_usage() { + let atomic = AtomicF32::new(0.0); let success_reward = 0.3; let iterations = 5; - + // Accumulate fractional tokens - for i in 1..=iterations { + for _ in 1..=iterations { let current = atomic.load(); atomic.store(current + success_reward); } - + let accumulated = atomic.load(); - let expected_total = iterations as f64 * success_reward; // 1.5 - + let expected_total = iterations as f32 * success_reward; // 1.5 + // Test the floor() operation pattern let full_tokens = accumulated.floor(); atomic.store(accumulated - full_tokens); let remaining = atomic.load(); - + // These assertions should be general: assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc. assert!(remaining >= 0.0 && remaining < 1.0); assert_eq!(remaining, expected_total - expected_total.floor()); } - + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] #[test] - fn test_atomicf64_clone_creates_independent_copy() { - let original = AtomicF64::new(123.456); + fn test_atomicf32_clone_creates_independent_copy() { + let original = AtomicF32::new(123.456); let cloned = original.clone(); - + // Verify they start with the same value assert_eq!(original.load(), cloned.load()); - + // Verify they're independent - modifying one doesn't affect the other original.store(999.0); - assert_eq!(cloned.load(), 123.456, "Clone should be unaffected by original changes"); + assert_eq!( + cloned.load(), + 123.456, + "Clone should be unaffected by original changes" + ); assert_eq!(original.load(), 999.0, "Original should have new value"); } }