Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/core/azure_core/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub use response::{AsyncRawResponse, AsyncResponse, RawResponse, Response};
pub use typespec_client_core::http::response;
pub use typespec_client_core::http::{
new_http_client, AppendToUrlQuery, Context, DeserializeWith, Format, HttpClient, JsonFormat,
Method, NoFormat, StatusCode, Url, UrlExt,
Method, NoFormat, Sanitizer, StatusCode, Url, UrlExt,
};

pub use crate::error::check_success;
Expand Down
50 changes: 39 additions & 11 deletions sdk/core/azure_core/src/http/pager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,33 @@ impl<P> fmt::Debug for PageIterator<P> {
}
}

#[derive(Debug, Clone, Eq)]
enum State<C> {
#[derive(Clone, Eq)]
enum State<C>
where
C: AsRef<str>,
{
Init,
More(C),
Done,
}

impl<C> PartialEq for State<C> {
impl<C> fmt::Debug for State<C>
where
C: AsRef<str>,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
State::Init => write!(f, "Init"),
State::More(c) => f.debug_tuple("More").field(&c.as_ref()).finish(),
State::Done => write!(f, "Done"),
}
}
}

impl<C> PartialEq for State<C>
where
C: AsRef<str>,
{
fn eq(&self, other: &Self) -> bool {
// Only needs to compare if both states are Init or Done; internally, we don't care about any other states.
matches!(
Expand All @@ -829,7 +848,10 @@ impl<C> PartialEq for State<C> {
}

#[derive(Debug)]
struct StreamState<'a, C, F> {
struct StreamState<'a, C, F>
where
C: AsRef<str>,
{
state: State<C>,
make_request: F,
continuation_token: Arc<Mutex<Option<String>>>,
Expand Down Expand Up @@ -863,9 +885,21 @@ where
added_span: false,
},
|mut stream_state| async move {
// When in the "Init" state, we are either starting fresh or resuming from a continuation token. In either case,
// attach a span to the context for the entire paging operation.
if stream_state.state == State::Init {
tracing::debug!("establish a public API span for new pager.");

// At the very start of polling, create a span for the entire request, and attach it to the context
let span = create_public_api_span(&stream_state.ctx, None, None);
if let Some(ref s) = span {
stream_state.added_span = true;
stream_state.ctx = stream_state.ctx.with_value(s.clone());
}
}

// Get the `continuation_token` to pick up where we left off, or None for the initial page,
// but don't override the terminal `State::Done`.

if stream_state.state != State::Done {
let result = match stream_state.continuation_token.lock() {
Ok(next_token) => match next_token.as_deref() {
Expand Down Expand Up @@ -895,12 +929,6 @@ where
let result = match stream_state.state {
State::Init => {
tracing::debug!("initial page request");
// At the very start of polling, create a span for the entire request, and attach it to the context
let span = create_public_api_span(&stream_state.ctx, None, None);
if let Some(ref s) = span {
stream_state.added_span = true;
stream_state.ctx = stream_state.ctx.with_value(s.clone());
}
(stream_state.make_request)(PagerState::Initial, stream_state.ctx.clone()).await
}
State::More(n) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,7 @@ mod tests {
status: SpanStatus::Unset,
kind: SpanKind::Internal,
span_id: Uuid::new_v4(),
parent_id: None,
attributes: vec![],
..Default::default()
}],
}],
)
Expand Down Expand Up @@ -599,9 +598,9 @@ mod tests {
span_name: "MyClient.MyApi",
status: SpanStatus::Unset,
span_id: Uuid::new_v4(),
parent_id: None,
kind: SpanKind::Internal,
attributes: vec![(AZ_NAMESPACE_ATTRIBUTE, "test namespace".into())],
..Default::default()
}],
}],
);
Expand Down Expand Up @@ -652,6 +651,7 @@ mod tests {
(AZ_NAMESPACE_ATTRIBUTE, "test namespace".into()),
(ERROR_TYPE_ATTRIBUTE, "500".into()),
],
..Default::default()
}],
}],
);
Expand Down Expand Up @@ -701,6 +701,7 @@ mod tests {
(AZ_NAMESPACE_ATTRIBUTE, "test.namespace".into()),
("az.fake_attribute", "attribute value".into()),
],
..Default::default()
},
ExpectedSpanInformation {
span_name: "PUT",
Expand All @@ -716,6 +717,7 @@ mod tests {
("server.port", 80.into()),
("http.response.status_code", 200.into()),
],
..Default::default()
},
],
}],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ pub(crate) mod tests {
),
),
],
..Default::default()
}],
}],
);
Expand Down Expand Up @@ -388,6 +389,7 @@ pub(crate) mod tests {
AttributeValue::from("https://example.com/client_request_id"),
),
],
..Default::default()
}],
}],
);
Expand Down Expand Up @@ -433,6 +435,7 @@ pub(crate) mod tests {
(SERVER_ADDRESS_ATTRIBUTE, AttributeValue::from("host")),
(SERVER_PORT_ATTRIBUTE, AttributeValue::from(8080)),
],
..Default::default()
}],
}],
);
Expand Down Expand Up @@ -502,6 +505,7 @@ pub(crate) mod tests {
AttributeValue::from("https://microsoft.com/request_failed.htm"),
),
],
..Default::default()
}],
}],
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,7 @@ mod tests {
let package_version = env!("CARGO_PKG_VERSION").to_string();
azure_core_test::tracing::assert_instrumentation_information(
|tracer_provider| Ok(create_service_client(&ctx, tracer_provider)),
|client| {
let client = client;
Box::pin(async move { client.get("get", None).await })
},
async move |client| client.get("get", None).await,
ExpectedInstrumentation {
package_name,
package_version,
Expand All @@ -706,14 +703,11 @@ mod tests {

#[recorded::test()]
async fn test_function_tracing_tests(ctx: TestContext) -> Result<()> {
let package_name = env!("CARGO_PKG_NAME").to_string();
let package_name = ctx.recording().var("CARGO_PKG_NAME", None).to_string();
let package_version = env!("CARGO_PKG_VERSION").to_string();
azure_core_test::tracing::assert_instrumentation_information(
|tracer_provider| Ok(create_service_client(&ctx, tracer_provider)),
|client| {
let client = client;
Box::pin(async move { client.get_with_function_tracing("get", None).await })
},
async move |client| client.get_with_function_tracing("get", None).await,
ExpectedInstrumentation {
package_name,
package_version,
Expand All @@ -737,14 +731,11 @@ mod tests {
async fn test_function_tracing_tests_error(ctx: TestContext) -> Result<()> {
use azure_core_test::tracing::ExpectedRestApiSpan;

let package_name = env!("CARGO_PKG_NAME").to_string();
let package_name = ctx.recording().var("CARGO_PKG_NAME", None).to_string();
let package_version = env!("CARGO_PKG_VERSION").to_string();
azure_core_test::tracing::assert_instrumentation_information(
|tracer_provider| Ok(create_service_client(&ctx, tracer_provider)),
|client| {
let client = client;
Box::pin(async move { client.get_with_function_tracing("index.htm", None).await })
},
async move |client| client.get_with_function_tracing("index.htm", None).await,
ExpectedInstrumentation {
package_name,
package_version,
Expand Down
Loading