diff --git a/sdk/core/azure_core/src/http/mod.rs b/sdk/core/azure_core/src/http/mod.rs index 9ad5ce8009..465ab60825 100644 --- a/sdk/core/azure_core/src/http/mod.rs +++ b/sdk/core/azure_core/src/http/mod.rs @@ -23,7 +23,7 @@ pub use response::{AsyncRawResponse, AsyncResponse, RawResponse, Response}; pub use typespec_client_core::http::response; pub use typespec_client_core::http::{ new_http_client, AppendToUrlQuery, Context, DeserializeWith, Format, HttpClient, JsonFormat, - Method, NoFormat, StatusCode, Url, UrlExt, + Method, NoFormat, Sanitizer, StatusCode, Url, UrlExt, }; pub use crate::error::check_success; diff --git a/sdk/core/azure_core/src/http/pager.rs b/sdk/core/azure_core/src/http/pager.rs index 9ae1b94f95..719ff79d93 100644 --- a/sdk/core/azure_core/src/http/pager.rs +++ b/sdk/core/azure_core/src/http/pager.rs @@ -811,14 +811,33 @@ impl

fmt::Debug for PageIterator

{ } } -#[derive(Debug, Clone, Eq)] -enum State { +#[derive(Clone, Eq)] +enum State +where + C: AsRef, +{ Init, More(C), Done, } -impl PartialEq for State { +impl fmt::Debug for State +where + C: AsRef, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + State::Init => write!(f, "Init"), + State::More(c) => f.debug_tuple("More").field(&c.as_ref()).finish(), + State::Done => write!(f, "Done"), + } + } +} + +impl PartialEq for State +where + C: AsRef, +{ fn eq(&self, other: &Self) -> bool { // Only needs to compare if both states are Init or Done; internally, we don't care about any other states. matches!( @@ -829,7 +848,10 @@ impl PartialEq for State { } #[derive(Debug)] -struct StreamState<'a, C, F> { +struct StreamState<'a, C, F> +where + C: AsRef, +{ state: State, make_request: F, continuation_token: Arc>>, @@ -863,9 +885,21 @@ where added_span: false, }, |mut stream_state| async move { + // When in the "Init" state, we are either starting fresh or resuming from a continuation token. In either case, + // attach a span to the context for the entire paging operation. + if stream_state.state == State::Init { + tracing::debug!("establish a public API span for new pager."); + + // At the very start of polling, create a span for the entire request, and attach it to the context + let span = create_public_api_span(&stream_state.ctx, None, None); + if let Some(ref s) = span { + stream_state.added_span = true; + stream_state.ctx = stream_state.ctx.with_value(s.clone()); + } + } + // Get the `continuation_token` to pick up where we left off, or None for the initial page, // but don't override the terminal `State::Done`. - if stream_state.state != State::Done { let result = match stream_state.continuation_token.lock() { Ok(next_token) => match next_token.as_deref() { @@ -895,12 +929,6 @@ where let result = match stream_state.state { State::Init => { tracing::debug!("initial page request"); - // At the very start of polling, create a span for the entire request, and attach it to the context - let span = create_public_api_span(&stream_state.ctx, None, None); - if let Some(ref s) = span { - stream_state.added_span = true; - stream_state.ctx = stream_state.ctx.with_value(s.clone()); - } (stream_state.make_request)(PagerState::Initial, stream_state.ctx.clone()).await } State::More(n) => { diff --git a/sdk/core/azure_core/src/http/policies/instrumentation/public_api_instrumentation.rs b/sdk/core/azure_core/src/http/policies/instrumentation/public_api_instrumentation.rs index 3c9abbfb15..a577cd1ae0 100644 --- a/sdk/core/azure_core/src/http/policies/instrumentation/public_api_instrumentation.rs +++ b/sdk/core/azure_core/src/http/policies/instrumentation/public_api_instrumentation.rs @@ -555,8 +555,7 @@ mod tests { status: SpanStatus::Unset, kind: SpanKind::Internal, span_id: Uuid::new_v4(), - parent_id: None, - attributes: vec![], + ..Default::default() }], }], ) @@ -599,9 +598,9 @@ mod tests { span_name: "MyClient.MyApi", status: SpanStatus::Unset, span_id: Uuid::new_v4(), - parent_id: None, kind: SpanKind::Internal, attributes: vec![(AZ_NAMESPACE_ATTRIBUTE, "test namespace".into())], + ..Default::default() }], }], ); @@ -652,6 +651,7 @@ mod tests { (AZ_NAMESPACE_ATTRIBUTE, "test namespace".into()), (ERROR_TYPE_ATTRIBUTE, "500".into()), ], + ..Default::default() }], }], ); @@ -701,6 +701,7 @@ mod tests { (AZ_NAMESPACE_ATTRIBUTE, "test.namespace".into()), ("az.fake_attribute", "attribute value".into()), ], + ..Default::default() }, ExpectedSpanInformation { span_name: "PUT", @@ -716,6 +717,7 @@ mod tests { ("server.port", 80.into()), ("http.response.status_code", 200.into()), ], + ..Default::default() }, ], }], diff --git a/sdk/core/azure_core/src/http/policies/instrumentation/request_instrumentation.rs b/sdk/core/azure_core/src/http/policies/instrumentation/request_instrumentation.rs index 442f5cdab3..d221ff0eaf 100644 --- a/sdk/core/azure_core/src/http/policies/instrumentation/request_instrumentation.rs +++ b/sdk/core/azure_core/src/http/policies/instrumentation/request_instrumentation.rs @@ -302,6 +302,7 @@ pub(crate) mod tests { ), ), ], + ..Default::default() }], }], ); @@ -388,6 +389,7 @@ pub(crate) mod tests { AttributeValue::from("https://example.com/client_request_id"), ), ], + ..Default::default() }], }], ); @@ -433,6 +435,7 @@ pub(crate) mod tests { (SERVER_ADDRESS_ATTRIBUTE, AttributeValue::from("host")), (SERVER_PORT_ATTRIBUTE, AttributeValue::from(8080)), ], + ..Default::default() }], }], ); @@ -502,6 +505,7 @@ pub(crate) mod tests { AttributeValue::from("https://microsoft.com/request_failed.htm"), ), ], + ..Default::default() }], }], ); diff --git a/sdk/core/azure_core_opentelemetry/tests/telemetry_service_macros.rs b/sdk/core/azure_core_opentelemetry/tests/telemetry_service_macros.rs index 4f5919a565..0fa1ff590f 100644 --- a/sdk/core/azure_core_opentelemetry/tests/telemetry_service_macros.rs +++ b/sdk/core/azure_core_opentelemetry/tests/telemetry_service_macros.rs @@ -685,10 +685,7 @@ mod tests { let package_version = env!("CARGO_PKG_VERSION").to_string(); azure_core_test::tracing::assert_instrumentation_information( |tracer_provider| Ok(create_service_client(&ctx, tracer_provider)), - |client| { - let client = client; - Box::pin(async move { client.get("get", None).await }) - }, + async move |client| client.get("get", None).await, ExpectedInstrumentation { package_name, package_version, @@ -706,14 +703,11 @@ mod tests { #[recorded::test()] async fn test_function_tracing_tests(ctx: TestContext) -> Result<()> { - let package_name = env!("CARGO_PKG_NAME").to_string(); + let package_name = ctx.recording().var("CARGO_PKG_NAME", None).to_string(); let package_version = env!("CARGO_PKG_VERSION").to_string(); azure_core_test::tracing::assert_instrumentation_information( |tracer_provider| Ok(create_service_client(&ctx, tracer_provider)), - |client| { - let client = client; - Box::pin(async move { client.get_with_function_tracing("get", None).await }) - }, + async move |client| client.get_with_function_tracing("get", None).await, ExpectedInstrumentation { package_name, package_version, @@ -737,14 +731,11 @@ mod tests { async fn test_function_tracing_tests_error(ctx: TestContext) -> Result<()> { use azure_core_test::tracing::ExpectedRestApiSpan; - let package_name = env!("CARGO_PKG_NAME").to_string(); + let package_name = ctx.recording().var("CARGO_PKG_NAME", None).to_string(); let package_version = env!("CARGO_PKG_VERSION").to_string(); azure_core_test::tracing::assert_instrumentation_information( |tracer_provider| Ok(create_service_client(&ctx, tracer_provider)), - |client| { - let client = client; - Box::pin(async move { client.get_with_function_tracing("index.htm", None).await }) - }, + async move |client| client.get_with_function_tracing("index.htm", None).await, ExpectedInstrumentation { package_name, package_version, diff --git a/sdk/core/azure_core_test/src/tracing.rs b/sdk/core/azure_core_test/src/tracing.rs index dc4296144f..484ed5823c 100644 --- a/sdk/core/azure_core_test/src/tracing.rs +++ b/sdk/core/azure_core_test/src/tracing.rs @@ -18,7 +18,6 @@ use std::{ borrow::Cow, collections::HashMap, fmt::Debug, - pin::Pin, sync::{Arc, Mutex}, }; @@ -65,10 +64,10 @@ impl TracerProvider for MockTracingProvider { /// Mock Tracer - used for testing distributed tracing without involving a specific tracing implementation. #[derive(Debug)] pub struct MockTracer { - pub namespace: Option<&'static str>, - pub package_name: &'static str, - pub package_version: Option<&'static str>, - pub spans: Mutex>>, + namespace: Option<&'static str>, + package_name: &'static str, + package_version: Option<&'static str>, + spans: Mutex>>, } impl Tracer for MockTracer { @@ -83,9 +82,14 @@ impl Tracer for MockTracer { attributes: Vec, parent: Arc, ) -> Arc { - let span = Arc::new(MockSpan::new(name, kind, attributes.clone(), Some(parent))); + let span = Arc::new(MockSpanInner::new( + name, + kind, + attributes.clone(), + Some(parent), + )); self.spans.lock().unwrap().push(span.clone()); - span + Arc::new(MockSpan { inner: span }) } fn start_span( @@ -101,15 +105,15 @@ impl Tracer for MockTracer { value: attr.value.clone(), }) .collect(); - let span = Arc::new(MockSpan::new(name, kind, attributes, None)); + let span = Arc::new(MockSpanInner::new(name, kind, attributes, None)); self.spans.lock().unwrap().push(span.clone()); - span + Arc::new(MockSpan { inner: span }) } } /// Mock span for testing purposes. #[derive(Debug)] -pub struct MockSpan { +struct MockSpanInner { pub name: Cow<'static, str>, pub kind: SpanKind, pub parent: Option<[u8; 8]>, @@ -118,7 +122,7 @@ pub struct MockSpan { pub state: Mutex, pub is_open: Mutex, } -impl MockSpan { +impl MockSpanInner { fn new( name: C, kind: SpanKind, @@ -144,9 +148,22 @@ impl MockSpan { is_open: Mutex::new(true), } } + + fn is_open(&self) -> bool { + let is_open = self.is_open.lock().unwrap(); + *is_open + } } -impl Span for MockSpan { +impl AsAny for MockSpanInner { + fn as_any(&self) -> &dyn std::any::Any { + // Convert to an object that doesn't expose the lifetime parameter + // We're essentially erasing the lifetime here to satisfy the static requirement + self as &dyn std::any::Any + } +} + +impl Span for MockSpanInner { fn set_attribute(&self, key: &'static str, value: AttributeValue) { eprintln!("{}: Setting attribute {}: {:?}", self.name, key, value); let mut attributes = self.attributes.lock().unwrap(); @@ -195,6 +212,19 @@ impl Span for MockSpan { } } +pub struct MockSpan { + inner: Arc, +} + +impl Drop for MockSpan { + fn drop(&mut self) { + if self.inner.is_open() { + eprintln!("Warning: Dropping open span: {}", self.inner.name); + self.inner.end(); + } + } +} + impl AsAny for MockSpan { fn as_any(&self) -> &dyn std::any::Any { // Convert to an object that doesn't expose the lifetime parameter @@ -203,6 +233,40 @@ impl AsAny for MockSpan { } } +impl Span for MockSpan { + fn set_attribute(&self, key: &'static str, value: AttributeValue) { + self.inner.set_attribute(key, value); + } + + fn set_status(&self, status: crate::tracing::SpanStatus) { + self.inner.set_status(status); + } + + fn end(&self) { + self.inner.end(); + } + + fn is_recording(&self) -> bool { + self.inner.is_recording() + } + + fn span_id(&self) -> [u8; 8] { + self.inner.span_id() + } + + fn record_error(&self, error: &dyn std::error::Error) { + self.inner.record_error(error); + } + + fn set_current(&self, context: &Context) -> Box { + self.inner.set_current(context) + } + + fn propagate_headers(&self, request: &mut Request) { + self.inner.propagate_headers(request); + } +} + /// Expected information about a tracer. #[derive(Debug)] pub struct ExpectedTracerInformation<'a> { @@ -252,22 +316,64 @@ pub fn check_instrumentation_result( assert_eq!(tracer.namespace, expected.namespace); let spans = tracer.spans.lock().unwrap(); - assert_eq!( - spans.len(), - expected.spans.len(), - "Unexpected number of spans for tracer {}", - expected.name - ); - for (span_index, span_expected) in expected.spans.iter().enumerate() { + // Check span lengths if there are no wildcard spans. + if !expected.spans.iter().any(|s| s.is_wildcard) { + assert_eq!( + spans.len(), + expected.spans.len(), + "Unexpected number of spans for tracer {}", + expected.name + ); + } + + let mut expected_index = 0; + for (span_index, span_actual) in spans.iter().enumerate() { eprintln!( "Checking span {} of tracer {}: {}", - span_index, expected.name, span_expected.span_name + span_index, expected.name, span_actual.name + ); + check_span_information( + span_actual, + &expected.spans[expected_index], + &parent_span_map, ); - check_span_information(&spans[span_index], span_expected, &parent_span_map); // Now that we've verified the span, add the mapping between expected span ID and the actual span ID. - parent_span_map.insert(span_expected.span_id, spans[span_index].id); + parent_span_map.insert(expected.spans[expected_index].span_id, span_actual.id); + if expected.spans[expected_index].is_wildcard { + // If this is a wildcard span, we don't increment the expected index. + eprintln!( + "Span {} is a wildcard, not incrementing expected index", + span_actual.name + ); + if spans.len() > span_index + 1 { + let next_span = &spans[span_index + 1]; + if !compare_span_information( + next_span, + &expected.spans[expected_index], + &parent_span_map, + ) { + eprintln!( + "Next actual span does not match expected span: {}", + expected.spans[expected_index].span_name + ); + expected_index += 1; + } + } else { + // At the very end, bump the expected index past the wildcard entry. + // This ensures that we consume all the expected spans. + expected_index += 1; + } + } else { + expected_index += 1; + } } + assert_eq!( + expected_index, + expected.spans.len(), + "Not all expected spans were found for tracer {}", + expected.name + ); } } @@ -290,10 +396,26 @@ pub struct ExpectedSpanInformation<'a> { /// The expected attributes associated with the span. pub attributes: Vec<(&'a str, AttributeValue)>, + + pub is_wildcard: bool, +} + +impl Default for ExpectedSpanInformation<'_> { + fn default() -> Self { + Self { + span_name: "get", + status: SpanStatus::Unset, + span_id: Uuid::new_v4(), + parent_id: None, + kind: SpanKind::Client, + attributes: vec![], + is_wildcard: false, + } + } } fn check_span_information( - span: &Arc, + span: &Arc, expected: &ExpectedSpanInformation<'_>, parent_span_map: &HashMap, ) { @@ -342,6 +464,64 @@ fn check_span_information( ); } +/// Returns true if the spans match, false otherwise. +fn compare_span_information( + actual: &Arc, + expected: &ExpectedSpanInformation<'_>, + parent_span_map: &HashMap, +) -> bool { + if actual.name != expected.span_name { + return false; + } + if actual.kind != expected.kind { + return false; + } + if *actual.state.lock().unwrap() != expected.status { + return false; + } + match actual.parent { + None => { + if expected.parent_id.is_some() { + return false; + } + } + Some(ref parent) => { + let parent_id = parent_span_map + .get(expected.parent_id.as_ref().unwrap()) + .unwrap(); + if *parent != *parent_id { + return false; + } + } + } + let attributes = actual.attributes.lock().unwrap(); + eprintln!("Expected attributes: {:?}", expected.attributes); + eprintln!("Found attributes: {:?}", attributes); + for (index, attr) in attributes.iter().enumerate() { + eprintln!("Attribute {}: {} = {:?}", index, attr.key, attr.value); + let mut found = false; + for (key, value) in &expected.attributes { + if attr.key == *key { + // Skip checking the value for "" as it is a placeholder + if *value != AttributeValue::String("".into()) && attr.value != *value { + return false; + } + found = true; + break; + } + } + if !found { + return false; + } + } + for (key, _) in expected.attributes.iter() { + if !attributes.iter().any(|attr| attr.key == *key) { + return false; + } + } + true +} + /// Information about an instrumented API call. /// /// This structure is used to collect information about a specific API call that is being instrumented for tracing. @@ -390,6 +570,9 @@ pub struct ExpectedRestApiSpan { /// Expected status code returned by the service. pub expected_status_code: azure_core::http::StatusCode, + + /// Whether an unknown multiple of this span will be found. + pub is_wildcard: bool, } impl Default for ExpectedRestApiSpan { @@ -397,6 +580,7 @@ impl Default for ExpectedRestApiSpan { Self { api_verb: azure_core::http::Method::Get, expected_status_code: azure_core::http::StatusCode::Ok, + is_wildcard: false, } } } @@ -405,8 +589,14 @@ impl Default for ExpectedRestApiSpan { #[derive(Debug, Default, Clone)] pub struct ExpectedInstrumentation { /// The package name for the service client. + /// + /// **NOTE**: Make sure that the package name comes from `env!("CARGO_PKG_NAME")` to ensure that this continues to work + /// if test recordings were created with a previous version of the package. pub package_name: String, /// The package version for the service client. + /// + /// **NOTE**: Make sure that the package version comes from `env!("CARGO_PKG_VERSION")` to ensure that this continues to work + /// if test recordings were created with a previous version of the package. pub package_version: String, /// The namespace for the service client. pub package_namespace: Option<&'static str>, @@ -437,6 +627,7 @@ pub struct ExpectedInstrumentation { /// The `test_api` call may issue multiple service client calls, if it does, this function will verify that all expected spans were created. The caller of the `test_instrumentation_for_api` call /// should make sure to include all expected APIs in the call. /// +/// pub async fn assert_instrumentation_information( create_client: FnInit, test_api: FnTest, @@ -444,7 +635,7 @@ pub async fn assert_instrumentation_information( ) -> azure_core::Result<()> where FnInit: FnOnce(Arc) -> azure_core::Result, - FnTest: FnOnce(C) -> Pin>>>, + FnTest: AsyncFnOnce(C) -> azure_core::Result, { // Initialize the mock tracer provider let mock_tracer = Arc::new(MockTracingProvider::new()); @@ -512,6 +703,7 @@ where status: span_status, kind: SpanKind::Internal, parent_id: None, + is_wildcard: false, // Public API spans cannot be wildcards. attributes: public_api_attributes, }); } @@ -554,6 +746,9 @@ where } else { None }, + // If allow_unknown_children is set, we don't know how many child spans there will be. + // Use a wildcard span ID to indicate that. + is_wildcard: rest_api_call.is_wildcard, span_id: Uuid::new_v4(), status: if !rest_api_call.expected_status_code.is_success() { SpanStatus::Error { diff --git a/sdk/keyvault/assets.json b/sdk/keyvault/assets.json index debd05b5d2..595344126c 100644 --- a/sdk/keyvault/assets.json +++ b/sdk/keyvault/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "rust", "TagPrefix": "rust/keyvault", - "Tag": "rust/keyvault_7e75eabce0" + "Tag": "rust/keyvault_5961c5368d" } diff --git a/sdk/keyvault/azure_security_keyvault_certificates/tests/certificate_client.rs b/sdk/keyvault/azure_security_keyvault_certificates/tests/certificate_client.rs index 4e33ce9273..44fcaf338b 100644 --- a/sdk/keyvault/azure_security_keyvault_certificates/tests/certificate_client.rs +++ b/sdk/keyvault/azure_security_keyvault_certificates/tests/certificate_client.rs @@ -25,7 +25,7 @@ use azure_security_keyvault_keys::{ KeyClient, KeyClientOptions, }; use azure_security_keyvault_test::Retry; -use futures::{FutureExt, TryStreamExt}; +use futures::TryStreamExt; use openssl::sha::sha256; use std::{collections::HashMap, sync::LazyLock}; @@ -93,24 +93,21 @@ async fn certificate_validate_instrumentation(ctx: TestContext) -> Result<()> { )?; Ok(client) }, - |client| { - async move { - // Create a self-signed certificate. - let body = CreateCertificateParameters { - certificate_policy: Some(DEFAULT_CERTIFICATE_POLICY.clone()), - ..Default::default() - }; - let _certificate = client - .create_certificate( - "certificate-validate-instrumentation", - body.try_into()?, - None, - )? - .await? - .into_model()?; - Ok(()) - } - .boxed() + async move |client| { + // Create a self-signed certificate. + let body = CreateCertificateParameters { + certificate_policy: Some(DEFAULT_CERTIFICATE_POLICY.clone()), + ..Default::default() + }; + let _certificate = client + .create_certificate( + "certificate-validate-instrumentation", + body.try_into()?, + None, + )? + .await? + .into_model()?; + Ok(()) }, ExpectedInstrumentation { package_name: recording.var("CARGO_PKG_NAME", None), @@ -122,14 +119,12 @@ async fn certificate_validate_instrumentation(ctx: TestContext) -> Result<()> { ExpectedRestApiSpan { api_verb: Method::Post, expected_status_code: StatusCode::Accepted, + is_wildcard: false, }, ExpectedRestApiSpan { api_verb: Method::Get, expected_status_code: StatusCode::Ok, - }, - ExpectedRestApiSpan { - api_verb: Method::Get, - expected_status_code: StatusCode::Ok, + is_wildcard: true, }, ], ..Default::default() diff --git a/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs b/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs index e1b531e063..7d24a67276 100644 --- a/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs +++ b/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs @@ -268,6 +268,8 @@ async fn round_trip_secret_verify_telemetry(ctx: TestContext) -> Result<()> { }, ExpectedInstrumentation { package_name: recording.var("CARGO_PKG_NAME", None), + + // Don't use `recording.var` here in case the recording was made with a different package version. package_version: env!("CARGO_PKG_VERSION").into(), package_namespace: Some("KeyVault"), api_calls: vec![ @@ -342,44 +344,109 @@ async fn list_secrets_verify_telemetry(ctx: TestContext) -> Result<()> { Some(options), ) }, - |client: SecretClient| { - Box::pin(async move { - let mut secrets = client.list_secret_properties(None)?; - while let Some(secret) = secrets.try_next().await? { - let _ = secret.resource_id()?; - } + async move |client: SecretClient| { + let mut secrets = client.list_secret_properties(None)?; + while let Some(secret) = secrets.try_next().await? { + let _ = secret.resource_id()?; + } - Ok(()) - }) + Ok(()) }, ExpectedInstrumentation { package_name: recording.var("CARGO_PKG_NAME", None), + // Don't use `recording.var` here in case the recording was made with a different package version. package_version: env!("CARGO_PKG_VERSION").into(), package_namespace: Some("KeyVault"), api_calls: vec![ExpectedApiInformation { api_name: Some("KeyVault.getSecrets"), - api_children: vec![ - ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, - ..Default::default() - }, - ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, - ..Default::default() - }, - ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, - ..Default::default() - }, - ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, - ..Default::default() - }, - ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, + api_children: vec![ExpectedRestApiSpan { + api_verb: azure_core::http::Method::Get, + is_wildcard: true, + ..Default::default() + }], + ..Default::default() + }], + }, + ) + .await; + + validate_result +} + +#[recorded::test] +async fn list_secrets_by_pages_verify_telemetry(ctx: TestContext) -> Result<()> { + use azure_core_test::tracing::ExpectedRestApiSpan; + + const SECRET_COUNT: usize = 50; + + let recording = ctx.recording(); + + { + let secret_client = { + let mut options = SecretClientOptions::default(); + recording.instrument(&mut options.client_options); + SecretClient::new( + recording.var("AZURE_KEYVAULT_URL", None).as_str(), + recording.credential(), + Some(options), + ) + }?; + for i in 0..SECRET_COUNT { + let secret = secret_client + .set_secret( + &format!("secret-list-telemetry-by-page{}", i), + SetSecretParameters { + value: Some(format!("secret-list-telemetry-by-page-value-{}", i)), ..Default::default() - }, - ], + } + .try_into()?, + None, + ) + .await? + .into_model()?; + assert_eq!( + secret.value, + Some(format!("secret-list-telemetry-by-page-value-{}", i)) + ); + } + } + // Verify that the distributed tracing traces generated from the API call below match the expected traces. + let validate_result = azure_core_test::tracing::assert_instrumentation_information( + |tracer_provider| { + let mut options = SecretClientOptions::default(); + recording.instrument(&mut options.client_options); + options.client_options.instrumentation = InstrumentationOptions { + tracer_provider: Some(tracer_provider), + }; + SecretClient::new( + recording.var("AZURE_KEYVAULT_URL", None).as_str(), + recording.credential(), + Some(options), + ) + }, + async move |client: SecretClient| { + let mut secrets = client.list_secret_properties(None)?.into_pages(); + while let Some(page) = secrets.try_next().await? { + let items = page.into_model()?; + for item in items.value { + let _ = item.resource_id()?; + } + } + + Ok(()) + }, + ExpectedInstrumentation { + package_name: recording.var("CARGO_PKG_NAME", None), + // Don't use `recording.var` here in case the recording was made with a different package version. + package_version: env!("CARGO_PKG_VERSION").into(), + package_namespace: Some("KeyVault"), + api_calls: vec![ExpectedApiInformation { + api_name: Some("KeyVault.getSecrets"), + api_children: vec![ExpectedRestApiSpan { + api_verb: azure_core::http::Method::Get, + is_wildcard: true, + ..Default::default() + }], ..Default::default() }], }, @@ -390,7 +457,6 @@ async fn list_secrets_verify_telemetry(ctx: TestContext) -> Result<()> { } #[recorded::test] -#[ignore = "Test does not currently work because instrumentation of PageIterators doesn't quite work."] async fn list_secrets_verify_telemetry_rehydrated(ctx: TestContext) -> Result<()> { use azure_core_test::tracing::ExpectedRestApiSpan; @@ -441,8 +507,8 @@ async fn list_secrets_verify_telemetry_rehydrated(ctx: TestContext) -> Result<() Some(options), ) }, - |client: SecretClient| { - Box::pin(async move { + async move |client: SecretClient| { + let rehydration_token = { let mut first_pager = client.list_secret_properties(None)?.into_pages(); // Prime the iteration. @@ -457,28 +523,28 @@ async fn list_secrets_verify_telemetry_rehydrated(ctx: TestContext) -> Result<() } } - let rehydration_token = first_pager + first_pager .continuation_token() - .expect("expected continuation token to be created after first page"); - - let mut rehydrated_pager = client - .list_secret_properties(None)? - .into_pages() - .with_continuation_token(rehydration_token); - - while let Some(secret_page) = rehydrated_pager.try_next().await? { - let secrets = secret_page.into_model()?; - for secret in secrets.value { - let _ = secret.resource_id()?; - } + .expect("expected continuation token to be created after first page") + }; + let mut rehydrated_pager = client + .list_secret_properties(None)? + .into_pages() + .with_continuation_token(rehydration_token); + + while let Some(secret_page) = rehydrated_pager.try_next().await? { + let secrets = secret_page.into_model()?; + for secret in secrets.value { + let _ = secret.resource_id()?; } + } - Ok(()) - }) + Ok(()) }, ExpectedInstrumentation { package_name: recording.var("CARGO_PKG_NAME", None), - package_version: recording.var("CARGO_PKG_VERSION", None), + // Don't use `recording.var` here in case the recording was made with a different package version. + package_version: env!("CARGO_PKG_VERSION").into(), package_namespace: Some("KeyVault"), api_calls: vec![ ExpectedApiInformation { @@ -493,30 +559,7 @@ async fn list_secrets_verify_telemetry_rehydrated(ctx: TestContext) -> Result<() api_name: Some("KeyVault.getSecrets"), api_children: vec![ExpectedRestApiSpan { api_verb: azure_core::http::Method::Get, - ..Default::default() - }], - ..Default::default() - }, - ExpectedApiInformation { - api_name: Some("KeyVault.getSecrets"), - api_children: vec![ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, - ..Default::default() - }], - ..Default::default() - }, - ExpectedApiInformation { - api_name: Some("KeyVault.getSecrets"), - api_children: vec![ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, - ..Default::default() - }], - ..Default::default() - }, - ExpectedApiInformation { - api_name: Some("KeyVault.getSecrets"), - api_children: vec![ExpectedRestApiSpan { - api_verb: azure_core::http::Method::Get, + is_wildcard: true, ..Default::default() }], ..Default::default()