Skip to content

Commit 4f22f96

Browse files
committed
feat(client): add some general HTTP/1 client middleware
1 parent b9dc3d2 commit 4f22f96

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

src/client/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/// Legacy implementations of `connect` module and `Client`
44
#[cfg(feature = "client-legacy")]
55
pub mod legacy;
6+
pub mod service;
67

78
#[cfg(feature = "client-proxy")]
89
pub mod proxy;

src/client/service.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
//! todo
2+
3+
use std::task::{Context, Poll};
4+
5+
use http::header::{HeaderValue, HOST};
6+
use http::{Method, Request, Uri};
7+
use tower_service::Service;
8+
9+
/// todo
10+
#[derive(Clone, Debug)]
11+
pub struct SetHost<S> {
12+
inner: S,
13+
}
14+
15+
/// todo
16+
pub struct Http1RequestTarget<S> {
17+
inner: S,
18+
}
19+
20+
// ===== impl SetHost =====
21+
22+
impl<S> SetHost<S> {
23+
/// todo
24+
pub fn new(inner: S) -> Self {
25+
SetHost { inner }
26+
}
27+
}
28+
29+
impl<S, ReqBody> Service<Request<ReqBody>> for SetHost<S>
30+
where
31+
S: Service<Request<ReqBody>>,
32+
{
33+
type Response = S::Response;
34+
type Error = S::Error;
35+
type Future = S::Future;
36+
37+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38+
self.inner.poll_ready(cx)
39+
}
40+
41+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
42+
if req.uri().authority().is_some() {
43+
let uri = req.uri().clone();
44+
req.headers_mut().entry(HOST).or_insert_with(|| {
45+
let hostname = uri.host().expect("authority implies host");
46+
if let Some(port) = get_non_default_port(&uri) {
47+
let s = format!("{hostname}:{port}");
48+
HeaderValue::from_str(&s)
49+
} else {
50+
HeaderValue::from_str(hostname)
51+
}
52+
.expect("uri host is valid header value")
53+
});
54+
}
55+
self.inner.call(req)
56+
}
57+
}
58+
59+
fn get_non_default_port(uri: &Uri) -> Option<http::uri::Port<&str>> {
60+
match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) {
61+
(Some(443), true) => None,
62+
(Some(80), false) => None,
63+
_ => uri.port(),
64+
}
65+
}
66+
67+
fn is_schema_secure(uri: &Uri) -> bool {
68+
uri.scheme_str()
69+
.map(|scheme_str| matches!(scheme_str, "wss" | "https"))
70+
.unwrap_or_default()
71+
}
72+
73+
// ===== impl Http1RequestTarget =====
74+
75+
impl<S> Http1RequestTarget<S> {
76+
/// todo
77+
pub fn new(inner: S) -> Self {
78+
Http1RequestTarget { inner }
79+
}
80+
}
81+
82+
impl<S, ReqBody> Service<Request<ReqBody>> for Http1RequestTarget<S>
83+
where
84+
S: Service<Request<ReqBody>>,
85+
{
86+
type Response = S::Response;
87+
type Error = S::Error;
88+
type Future = S::Future;
89+
90+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91+
self.inner.poll_ready(cx)
92+
}
93+
94+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
95+
// CONNECT always sends authority-form, so check it first...
96+
if req.method() == Method::CONNECT {
97+
authority_form(req.uri_mut());
98+
} else {
99+
origin_form(req.uri_mut());
100+
}
101+
self.inner.call(req)
102+
}
103+
}
104+
105+
fn origin_form(uri: &mut Uri) {
106+
let path = match uri.path_and_query() {
107+
Some(path) if path.as_str() != "/" => {
108+
let mut parts = ::http::uri::Parts::default();
109+
parts.path_and_query = Some(path.clone());
110+
Uri::from_parts(parts).expect("path is valid uri")
111+
}
112+
_none_or_just_slash => {
113+
debug_assert!(Uri::default() == "/");
114+
Uri::default()
115+
}
116+
};
117+
*uri = path
118+
}
119+
120+
fn authority_form(uri: &mut Uri) {
121+
if let Some(path) = uri.path_and_query() {
122+
// `https://hyper.rs` would parse with `/` path, don't
123+
// annoy people about that...
124+
if path != "/" {
125+
tracing::debug!("HTTP/1.1 CONNECT request stripping path: {:?}", path);
126+
}
127+
}
128+
*uri = match uri.authority() {
129+
Some(auth) => {
130+
let mut parts = ::http::uri::Parts::default();
131+
parts.authority = Some(auth.clone());
132+
Uri::from_parts(parts).expect("authority is valid")
133+
}
134+
None => {
135+
unreachable!("authority_form with relative uri");
136+
}
137+
};
138+
}

0 commit comments

Comments
 (0)