Skip to content

Commit 66a132a

Browse files
committed
easier wsupgrade api for protocols/exts, ability for client to verify protocols, etc.
1 parent c828d9f commit 66a132a

File tree

7 files changed

+109
-50
lines changed

7 files changed

+109
-50
lines changed

examples/hyper.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ fn main() {
3838
return;
3939
}
4040

41-
// TODO: same check like in server.rs
42-
let mut client = connection.accept().unwrap();
41+
let mut client = connection.use_protocol("rust-websocket").accept().unwrap();
4342

4443
let ip = client.peer_addr().unwrap();
4544

examples/server.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@ use std::thread;
44
use websocket::{Server, Message};
55
use websocket::message::Type;
66

7-
// TODO: I think the .reject() call is only for malformed packets
8-
// there should be an easy way to accept the socket with the given protocols
9-
// this would mean there should be a way to accept or reject on the client
10-
// Do you send the protocol you want to talk when you are not given it as an
11-
// option? What is a rejection response? Does the client check for it?
12-
// Client should expose what the decided protocols/extensions/etc are.
13-
// can you accept only one protocol??
14-
157
fn main() {
168
let server = Server::bind("127.0.0.1:2794").unwrap();
179

@@ -23,7 +15,7 @@ fn main() {
2315
return;
2416
}
2517

26-
let mut client = request.accept().unwrap();
18+
let mut client = request.use_protocol("rust-websocket").accept().unwrap();
2719

2820
let ip = client.peer_addr().unwrap();
2921

src/client/builder.rs

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,6 @@ use result::{WSUrlErrorKind, WebSocketResult, WebSocketError};
2020
use stream::Stream;
2121
use super::Client;
2222

23-
macro_rules! upsert_header {
24-
($headers:expr; $header:ty; {
25-
Some($pat:pat) => $some_match:expr,
26-
None => $default:expr
27-
}) => {{
28-
match $headers.has::<$header>() {
29-
true => {
30-
match $headers.get_mut::<$header>() {
31-
Some($pat) => { $some_match; },
32-
None => (),
33-
};
34-
}
35-
false => {
36-
$headers.set($default);
37-
},
38-
};
39-
}}
40-
}
41-
4223
/// Build clients with a builder-style API
4324
#[derive(Clone, Debug)]
4425
pub struct ClientBuilder<'u> {
@@ -305,6 +286,6 @@ impl<'u> ClientBuilder<'u> {
305286
return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'"));
306287
}
307288

308-
Ok(Client::unchecked(stream))
289+
Ok(Client::unchecked(stream, response.headers))
309290
}
310291
}

src/client/mod.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ extern crate url;
44
use std::net::TcpStream;
55
use std::net::SocketAddr;
66
use std::io::Result as IoResult;
7+
use hyper::header::Headers;
78

89
use ws;
910
use ws::sender::Sender as SenderTrait;
@@ -12,6 +13,8 @@ use ws::receiver::Receiver as ReceiverTrait;
1213
use result::WebSocketResult;
1314
use stream::{AsTcpStream, Stream, Splittable, Shutdown};
1415
use dataframe::DataFrame;
16+
use header::{WebSocketProtocol, WebSocketExtensions, Origin};
17+
use header::extensions::Extension;
1518

1619
use ws::dataframe::DataFrame as DataFrameable;
1720
use sender::Sender;
@@ -54,6 +57,7 @@ pub struct Client<S>
5457
where S: Stream
5558
{
5659
pub stream: S,
60+
headers: Headers,
5761
sender: Sender,
5862
receiver: Receiver,
5963
}
@@ -109,8 +113,9 @@ impl<S> Client<S>
109113
/// **without sending any handshake** this is meant to only be used with
110114
/// a stream that has a websocket connection already set up.
111115
/// If in doubt, don't use this!
112-
pub fn unchecked(stream: S) -> Self {
116+
pub fn unchecked(stream: S, headers: Headers) -> Self {
113117
Client {
118+
headers: headers,
114119
stream: stream,
115120
// NOTE: these are always true & false, see
116121
// https://tools.ietf.org/html/rfc6455#section-5
@@ -152,6 +157,28 @@ impl<S> Client<S>
152157
self.receiver.recv_message(self.stream.reader())
153158
}
154159

160+
pub fn headers(&self) -> &Headers {
161+
&self.headers
162+
}
163+
164+
pub fn protocols(&self) -> &[String] {
165+
self.headers
166+
.get::<WebSocketProtocol>()
167+
.map(|p| p.0.as_slice())
168+
.unwrap_or(&[])
169+
}
170+
171+
pub fn extensions(&self) -> &[Extension] {
172+
self.headers
173+
.get::<WebSocketExtensions>()
174+
.map(|e| e.0.as_slice())
175+
.unwrap_or(&[])
176+
}
177+
178+
pub fn origin(&self) -> Option<&str> {
179+
self.headers.get::<Origin>().map(|o| &o.0 as &str)
180+
}
181+
155182
pub fn stream_ref(&self) -> &S {
156183
&self.stream
157184
}
@@ -160,6 +187,10 @@ impl<S> Client<S>
160187
&mut self.stream
161188
}
162189

190+
pub fn into_stream(self) -> S {
191+
self.stream
192+
}
193+
163194
/// Returns an iterator over incoming messages.
164195
///
165196
///```no_run

src/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ pub use self::stream::Stream;
6060
pub use self::ws::Sender;
6161
pub use self::ws::Receiver;
6262

63+
macro_rules! upsert_header {
64+
($headers:expr; $header:ty; {
65+
Some($pat:pat) => $some_match:expr,
66+
None => $default:expr
67+
}) => {{
68+
match $headers.has::<$header>() {
69+
true => {
70+
match $headers.get_mut::<$header>() {
71+
Some($pat) => { $some_match; },
72+
None => (),
73+
};
74+
}
75+
false => {
76+
$headers.set($default);
77+
},
78+
};
79+
}}
80+
}
81+
82+
6383
pub mod ws;
6484
pub mod client;
6585
pub mod server;

src/server/upgrade/hyper.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> {
3030
let stream = reader.into_inner().get_mut();
3131

3232
Ok(WsUpgrade {
33+
headers: Headers::new(),
3334
stream: stream,
3435
request: Incoming {
3536
version: version,

src/server/upgrade/mod.rs

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,41 +37,74 @@ pub mod hyper;
3737
pub struct WsUpgrade<S>
3838
where S: Stream
3939
{
40-
stream: S,
41-
request: Request,
40+
pub headers: Headers,
41+
pub stream: S,
42+
pub request: Request,
4243
}
4344

4445
impl<S> WsUpgrade<S>
4546
where S: Stream
4647
{
47-
pub fn accept(self) -> IoResult<Client<S>> {
48+
pub fn use_protocol<P>(mut self, protocol: P) -> Self
49+
where P: Into<String>
50+
{
51+
upsert_header!(self.headers; WebSocketProtocol; {
52+
Some(protos) => protos.0.push(protocol.into()),
53+
None => WebSocketProtocol(vec![protocol.into()])
54+
});
55+
self
56+
}
57+
58+
pub fn use_extension(mut self, extension: Extension) -> Self {
59+
upsert_header!(self.headers; WebSocketExtensions; {
60+
Some(protos) => protos.0.push(extension),
61+
None => WebSocketExtensions(vec![extension])
62+
});
63+
self
64+
}
65+
66+
pub fn use_extensions<I>(mut self, extensions: I) -> Self
67+
where I: IntoIterator<Item = Extension>
68+
{
69+
let mut extensions: Vec<Extension> =
70+
extensions.into_iter().collect();
71+
upsert_header!(self.headers; WebSocketExtensions; {
72+
Some(protos) => protos.0.append(&mut extensions),
73+
None => WebSocketExtensions(extensions)
74+
});
75+
self
76+
}
77+
78+
pub fn accept(self) -> Result<Client<S>, (S, IoError)> {
4879
self.accept_with(&Headers::new())
4980
}
5081

51-
pub fn accept_with(mut self, custom_headers: &Headers) -> IoResult<Client<S>> {
52-
let mut headers = Headers::new();
53-
headers.extend(custom_headers.iter());
54-
headers.set(WebSocketAccept::new(
55-
// NOTE: we know there is a key because this is a valid request
56-
// i.e. to construct this you must go through the validate function
57-
self.request.headers.get::<WebSocketKey>().unwrap()
58-
));
59-
headers.set(Connection(vec![
82+
pub fn accept_with(mut self, custom_headers: &Headers) -> Result<Client<S>, (S, IoError)> {
83+
self.headers.extend(custom_headers.iter());
84+
self.headers
85+
.set(WebSocketAccept::new(// NOTE: we know there is a key because this is a valid request
86+
// i.e. to construct this you must go through the validate function
87+
self.request.headers.get::<WebSocketKey>().unwrap()));
88+
self.headers
89+
.set(Connection(vec![
6090
ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()))
6191
]));
62-
headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)]));
92+
self.headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)]));
6393

64-
try!(self.send(StatusCode::SwitchingProtocols, &headers));
94+
if let Err(e) = self.send(StatusCode::SwitchingProtocols) {
95+
return Err((self.stream, e));
96+
}
6597

66-
Ok(Client::unchecked(self.stream))
98+
Ok(Client::unchecked(self.stream, self.headers))
6799
}
68100

69101
pub fn reject(self) -> Result<S, (S, IoError)> {
70102
self.reject_with(&Headers::new())
71103
}
72104

73105
pub fn reject_with(mut self, headers: &Headers) -> Result<S, (S, IoError)> {
74-
match self.send(StatusCode::BadRequest, headers) {
106+
self.headers.extend(headers.iter());
107+
match self.send(StatusCode::BadRequest) {
75108
Ok(()) => Ok(self.stream),
76109
Err(e) => Err((self.stream, e)),
77110
}
@@ -113,12 +146,12 @@ impl<S> WsUpgrade<S>
113146
self.stream
114147
}
115148

116-
fn send(&mut self, status: StatusCode, headers: &Headers) -> IoResult<()> {
149+
fn send(&mut self, status: StatusCode) -> IoResult<()> {
117150
try!(write!(self.stream.writer(),
118151
"{} {}\r\n",
119152
self.request.version,
120153
status));
121-
try!(write!(self.stream.writer(), "{}\r\n", headers));
154+
try!(write!(self.stream.writer(), "{}\r\n", self.headers));
122155
Ok(())
123156
}
124157
}
@@ -173,6 +206,7 @@ impl<S> IntoWs for S
173206
match validate(&request.subject.0, &request.version, &request.headers) {
174207
Ok(_) => {
175208
Ok(WsUpgrade {
209+
headers: Headers::new(),
176210
stream: self,
177211
request: request,
178212
})
@@ -192,6 +226,7 @@ impl<S> IntoWs for RequestStreamPair<S>
192226
match validate(&self.1.subject.0, &self.1.version, &self.1.headers) {
193227
Ok(_) => {
194228
Ok(WsUpgrade {
229+
headers: Headers::new(),
195230
stream: self.0,
196231
request: self.1,
197232
})

0 commit comments

Comments
 (0)