|
| 1 | +/// This module implements the googletest test sharding protocol. The Google |
| 2 | +/// sharding protocol consists of the following environment variables: |
| 3 | +/// |
| 4 | +/// * GTEST_TOTAL_SHARDS: total number of shards. |
| 5 | +/// * GTEST_SHARD_INDEX: number of this shard |
| 6 | +/// * GTEST_SHARD_STATUS_FILE: touch this file to indicate support for sharding. |
| 7 | +/// |
| 8 | +/// See also <https://google.github.io/googletest/advanced.html> |
| 9 | +use std::cell::OnceCell; |
| 10 | +use std::env::{var, var_os}; |
| 11 | +use std::ffi::OsStr; |
| 12 | +use std::fs::{self, File}; |
| 13 | +use std::num::NonZeroU64; |
| 14 | +use std::path::{Path, PathBuf}; |
| 15 | + |
| 16 | +/// Environment variable specifying the total number of test shards. |
| 17 | +const TEST_TOTAL_SHARDS: &str = "GTEST_TOTAL_SHARDS"; |
| 18 | + |
| 19 | +/// Environment variable specifyign the index of this test shard. |
| 20 | +const TEST_SHARD_INDEX: &str = "GTEST_SHARD_INDEX"; |
| 21 | + |
| 22 | +/// Environment variable specifying the name of the file we create (or cause a |
| 23 | +/// timestamp change on) to indicate that we support the sharding protocol. |
| 24 | +const TEST_SHARD_STATUS_FILE: &str = "GTEST_SHARD_STATUS_FILE"; |
| 25 | + |
| 26 | +thread_local! { |
| 27 | + static SHARDING: OnceCell<Sharding> = const { OnceCell::new() }; |
| 28 | +} |
| 29 | + |
| 30 | +struct Sharding { |
| 31 | + this_shard: u64, |
| 32 | + total_shards: NonZeroU64, |
| 33 | +} |
| 34 | + |
| 35 | +impl Default for Sharding { |
| 36 | + fn default() -> Self { |
| 37 | + Self { this_shard: 0, total_shards: NonZeroU64::MIN } |
| 38 | + } |
| 39 | +} |
| 40 | + |
| 41 | +pub fn test_should_run(test_case_hash: u64) -> bool { |
| 42 | + SHARDING.with(|sharding_cell| { |
| 43 | + sharding_cell.get_or_init(Sharding::from_environment).test_should_run(test_case_hash) |
| 44 | + }) |
| 45 | +} |
| 46 | + |
| 47 | +impl Sharding { |
| 48 | + fn test_should_run(&self, test_case_hash: u64) -> bool { |
| 49 | + (test_case_hash % self.total_shards.get()) == self.this_shard |
| 50 | + } |
| 51 | + |
| 52 | + fn from_environment() -> Sharding { |
| 53 | + let this_shard: Option<u64> = |
| 54 | + { var(OsStr::new(TEST_SHARD_INDEX)).ok().and_then(|value| value.parse().ok()) }; |
| 55 | + let total_shards: Option<NonZeroU64> = { |
| 56 | + var(OsStr::new(TEST_TOTAL_SHARDS)) |
| 57 | + .ok() |
| 58 | + .and_then(|value| value.parse().ok()) |
| 59 | + .and_then(NonZeroU64::new) |
| 60 | + }; |
| 61 | + |
| 62 | + match (this_shard, total_shards) { |
| 63 | + (Some(this_shard), Some(total_shards)) if this_shard < total_shards.get() => { |
| 64 | + if let Some(name) = var_os(OsStr::new(TEST_SHARD_STATUS_FILE)) { |
| 65 | + let pathbuf = PathBuf::from(name); |
| 66 | + if let Err(e) = create_status_file(&pathbuf) { |
| 67 | + eprintln!( |
| 68 | + "failed to create {} file {}: {}", |
| 69 | + TEST_SHARD_STATUS_FILE, |
| 70 | + pathbuf.display(), |
| 71 | + e |
| 72 | + ); |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | + Sharding { this_shard, total_shards } |
| 77 | + } |
| 78 | + _ => Sharding::default(), |
| 79 | + } |
| 80 | + } |
| 81 | +} |
| 82 | + |
| 83 | +fn create_status_file(path: &Path) -> std::io::Result<()> { |
| 84 | + if let Some(parent) = path.parent() { |
| 85 | + fs::create_dir_all(parent)?; |
| 86 | + } |
| 87 | + |
| 88 | + File::create(path).map(|_| ()) |
| 89 | +} |
0 commit comments