Skip to content

Commit 7b0a8fc

Browse files
committed
feat: custom header support
Add true support for custom headers with full backwards compatibility.
1 parent d96982d commit 7b0a8fc

File tree

7 files changed

+151
-51
lines changed

7 files changed

+151
-51
lines changed

examples/custom_header.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use serde::{Deserialize, Serialize};
22
use std::collections::HashMap;
33

44
use jsonwebtoken::errors::ErrorKind;
5-
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
5+
use jsonwebtoken::{
6+
Algorithm, DecodingKey, EncodingKey, Validation, decode_with_custom_header, encode, header,
7+
};
68

79
#[derive(Debug, Serialize, Deserialize, Clone)]
810
struct Claims {
@@ -11,6 +13,19 @@ struct Claims {
1113
exp: u64,
1214
}
1315

16+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
17+
struct CustomHeader {
18+
alg: Algorithm,
19+
custom: String,
20+
another_custom_field: Option<usize>,
21+
}
22+
impl header::FromEncoded for CustomHeader {}
23+
impl header::Alg for CustomHeader {
24+
fn alg(&self) -> &Algorithm {
25+
&self.alg
26+
}
27+
}
28+
1429
fn main() {
1530
let my_claims =
1631
Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned(), exp: 10000000000 };
@@ -19,11 +34,10 @@ fn main() {
1934
let mut extras = HashMap::with_capacity(1);
2035
extras.insert("custom".to_string(), "header".to_string());
2136

22-
let header = Header {
23-
kid: Some("signing_key".to_owned()),
37+
let header = CustomHeader {
2438
alg: Algorithm::HS512,
25-
extras,
26-
..Default::default()
39+
custom: "custom".into(),
40+
another_custom_field: 42.into(),
2741
};
2842

2943
let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) {
@@ -32,7 +46,7 @@ fn main() {
3246
};
3347
println!("{:?}", token);
3448

35-
let token_data = match decode::<Claims>(
49+
let token_data = match decode_with_custom_header::<CustomHeader, Claims>(
3650
&token,
3751
&DecodingKey::from_secret(key),
3852
&Validation::new(Algorithm::HS512),

src/decoding.rs

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::Algorithm;
77
use crate::algorithms::AlgorithmFamily;
88
use crate::crypto::JwtVerifier;
99
use crate::errors::{ErrorKind, Result, new_error};
10-
use crate::header::Header;
10+
use crate::header::{Alg, FromEncoded, Header};
1111
use crate::jwk::{AlgorithmParameters, Jwk};
1212
#[cfg(feature = "use_pem")]
1313
use crate::pem::decoder::PemEncodedKey;
@@ -37,15 +37,16 @@ use crate::crypto::rust_crypto::{
3737

3838
/// The return type of a successful call to [decode](fn.decode.html).
3939
#[derive(Debug)]
40-
pub struct TokenData<T> {
40+
pub struct TokenData<H, T> {
4141
/// The decoded JWT header
42-
pub header: Header,
42+
pub header: H,
4343
/// The decoded JWT claims
4444
pub claims: T,
4545
}
4646

47-
impl<T> Clone for TokenData<T>
47+
impl<H, T> Clone for TokenData<H, T>
4848
where
49+
H: Clone,
4950
T: Clone,
5051
{
5152
fn clone(&self) -> Self {
@@ -281,21 +282,40 @@ pub fn decode<T: DeserializeOwned + Clone>(
281282
token: impl AsRef<[u8]>,
282283
key: &DecodingKey,
283284
validation: &Validation,
284-
) -> Result<TokenData<T>> {
285+
) -> Result<TokenData<Header, T>> {
286+
decode_with_custom_header(token, key, validation)
287+
}
288+
289+
/// Decode and validate a JWT with a custom header
290+
///
291+
/// If the token or its signature is invalid, or the claims fail validation, this will return an
292+
/// error.
293+
pub fn decode_with_custom_header<H, T>(
294+
token: impl AsRef<[u8]>,
295+
key: &DecodingKey,
296+
validation: &Validation,
297+
) -> Result<TokenData<H, T>>
298+
where
299+
H: DeserializeOwned + Clone + FromEncoded + Alg,
300+
T: DeserializeOwned + Clone,
301+
{
285302
let token = token.as_ref();
286-
let header = decode_header(token)?;
287303

288-
if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
304+
let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
305+
let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
306+
let header = H::from_encoded(header)?;
307+
308+
if validation.validate_signature && !validation.algorithms.contains(header.alg()) {
289309
return Err(new_error(ErrorKind::InvalidAlgorithm));
290310
}
291311

292-
let verifying_provider = jwt_verifier_factory(&header.alg, key)?;
312+
let verifying_provider = jwt_verifier_factory(header.alg(), key)?;
313+
verify_signature_body(message, signature, &header, validation, verifying_provider)?;
293314

294-
let (header, claims) = verify_signature(token, validation, verifying_provider)?;
315+
let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(payload)?;
316+
validate(decoded_claims.deserialize()?, validation)?;
295317

296-
let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?;
297318
let claims = decoded_claims.deserialize()?;
298-
validate(decoded_claims.deserialize()?, validation)?;
299319

300320
Ok(TokenData { header, claims })
301321
}
@@ -357,10 +377,21 @@ pub fn decode_header(token: impl AsRef<[u8]>) -> Result<Header> {
357377
Header::from_encoded(header)
358378
}
359379

380+
/// Decode only the custom header of a JWT without decoding or validating the payload
381+
pub fn decode_custom_header<H>(token: impl AsRef<[u8]>) -> Result<H>
382+
where
383+
H: DeserializeOwned + Clone + Alg + FromEncoded,
384+
{
385+
let token = token.as_ref();
386+
let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
387+
let (_, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
388+
H::from_encoded(header)
389+
}
390+
360391
pub(crate) fn verify_signature_body(
361392
message: &[u8],
362393
signature: &[u8],
363-
header: &Header,
394+
header: &impl Alg,
364395
validation: &Validation,
365396
verifying_provider: Box<dyn JwtVerifier>,
366397
) -> Result<()> {
@@ -376,7 +407,7 @@ pub(crate) fn verify_signature_body(
376407
}
377408
}
378409

379-
if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
410+
if validation.validate_signature && !validation.algorithms.contains(header.alg()) {
380411
return Err(new_error(ErrorKind::InvalidAlgorithm));
381412
}
382413

@@ -388,19 +419,3 @@ pub(crate) fn verify_signature_body(
388419

389420
Ok(())
390421
}
391-
392-
/// Verify the signature of a JWT, and return a header object and raw payload.
393-
///
394-
/// If the token or its signature is invalid, it will return an error.
395-
fn verify_signature<'a>(
396-
token: &'a [u8],
397-
validation: &Validation,
398-
verifying_provider: Box<dyn JwtVerifier>,
399-
) -> Result<(Header, &'a [u8])> {
400-
let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
401-
let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
402-
let header = Header::from_encoded(header)?;
403-
verify_signature_body(message, signature, &header, validation, verifying_provider)?;
404-
405-
Ok((header, payload))
406-
}

src/encoding.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::Algorithm;
1010
use crate::algorithms::AlgorithmFamily;
1111
use crate::crypto::JwtSigner;
1212
use crate::errors::{ErrorKind, Result, new_error};
13-
use crate::header::Header;
13+
use crate::header::Alg;
1414
#[cfg(feature = "use_pem")]
1515
use crate::pem::decoder::PemEncodedKey;
1616
use crate::serialization::{b64_encode, b64_encode_part};
@@ -171,14 +171,18 @@ impl Debug for EncodingKey {
171171
/// // This will create a JWT using HS256 as algorithm
172172
/// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap();
173173
/// ```
174-
pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> {
175-
if key.family != header.alg.family() {
174+
pub fn encode<H: Serialize + Alg, T: Serialize>(
175+
header: &H,
176+
claims: &T,
177+
key: &EncodingKey,
178+
) -> Result<String> {
179+
if key.family != header.alg().family() {
176180
return Err(new_error(ErrorKind::InvalidAlgorithm));
177181
}
178182

179-
let signing_provider = jwt_signer_factory(&header.alg, key)?;
183+
let signing_provider = jwt_signer_factory(header.alg(), key)?;
180184

181-
if signing_provider.algorithm() != header.alg {
185+
if signing_provider.algorithm() != *header.alg() {
182186
return Err(new_error(ErrorKind::InvalidAlgorithm));
183187
}
184188

src/header.rs

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
//! Traits and datastructures for JWT Headers
12
use std::collections::HashMap;
23
use std::result;
34

@@ -25,12 +26,19 @@ const ENC_A256GCM: &str = "A256GCM";
2526
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2627
#[allow(clippy::upper_case_acronyms, non_camel_case_types)]
2728
pub enum Enc {
29+
/// HMAC-256
2830
A128CBC_HS256,
31+
/// HMAC-384
2932
A192CBC_HS384,
33+
/// HMAC-512
3034
A256CBC_HS512,
35+
/// AES-GCM 128
3136
A128GCM,
37+
/// AES-GCM 192
3238
A192GCM,
39+
/// AES-GCM 256
3340
A256GCM,
41+
/// Other encryption type
3442
Other(String),
3543
}
3644

@@ -76,7 +84,9 @@ impl<'de> Deserialize<'de> for Enc {
7684
/// Defined in [RFC7516#4.1.3](https://datatracker.ietf.org/doc/html/rfc7516#section-4.1.3).
7785
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
7886
pub enum Zip {
87+
/// Basic Deflate Compression
7988
Deflate,
89+
/// Other Compression
8090
Other(String),
8191
}
8292

@@ -106,6 +116,25 @@ impl<'de> Deserialize<'de> for Zip {
106116
}
107117
}
108118

119+
/// Getter for `alg` attribute of a JWT Header
120+
/// This must be implemented by custom header structs
121+
pub trait Alg {
122+
/// Getter for `alg`
123+
fn alg(&self) -> &Algorithm;
124+
}
125+
126+
/// Decodes a JWT part from b64
127+
pub trait FromEncoded {
128+
/// Converts an encoded JWT part into the Header struct if possible
129+
fn from_encoded<T: AsRef<[u8]>>(encoded_part: T) -> Result<Self>
130+
where
131+
Self: Sized + serde::de::DeserializeOwned,
132+
{
133+
let decoded = b64_decode(encoded_part)?;
134+
Ok(serde_json::from_slice(&decoded)?)
135+
}
136+
}
137+
109138
/// A basic JWT header, the alg defaults to HS256 and typ is automatically
110139
/// set to `JWT`. All the other fields are optional.
111140
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -213,12 +242,6 @@ impl Header {
213242
}
214243
}
215244

216-
/// Converts an encoded part into the Header struct if possible
217-
pub(crate) fn from_encoded<T: AsRef<[u8]>>(encoded_part: T) -> Result<Self> {
218-
let decoded = b64_decode(encoded_part)?;
219-
Ok(serde_json::from_slice(&decoded)?)
220-
}
221-
222245
/// Decodes the X.509 certificate chain into ASN.1 DER format.
223246
pub fn x5c_der(&self) -> Result<Option<Vec<Vec<u8>>>> {
224247
Ok(self
@@ -237,3 +260,11 @@ impl Default for Header {
237260
Header::new(Algorithm::default())
238261
}
239262
}
263+
264+
impl Alg for Header {
265+
fn alg(&self) -> &Algorithm {
266+
&self.alg
267+
}
268+
}
269+
270+
impl FromEncoded for Header {}

src/jws.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use crate::crypto::sign;
55
use crate::errors::{ErrorKind, Result, new_error};
66
use crate::serialization::{DecodedJwtPartClaims, b64_encode_part};
77
use crate::validation::validate;
8-
use crate::{DecodingKey, EncodingKey, Header, TokenData, Validation};
8+
use crate::{
9+
DecodingKey, EncodingKey, TokenData, Validation,
10+
header::{FromEncoded, Header},
11+
};
912

1013
use crate::decoding::{jwt_verifier_factory, verify_signature_body};
1114
use serde::de::DeserializeOwned;
@@ -63,7 +66,7 @@ pub fn decode<T: DeserializeOwned>(
6366
jws: &Jws<T>,
6467
key: &DecodingKey,
6568
validation: &Validation,
66-
) -> Result<TokenData<T>> {
69+
) -> Result<TokenData<Header, T>> {
6770
let header = Header::from_encoded(&jws.protected)?;
6871
let message = [jws.protected.as_str(), jws.payload.as_str()].join(".");
6972

src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ compile_error!(
1414
compile_error!("at least one of the features \"rust_crypto\" or \"aws_lc_rs\" must be enabled");
1515

1616
pub use algorithms::Algorithm;
17-
pub use decoding::{DecodingKey, TokenData, decode, decode_header};
17+
pub use decoding::{
18+
DecodingKey, TokenData, decode, decode_custom_header, decode_header, decode_with_custom_header,
19+
};
1820
pub use encoding::{EncodingKey, encode};
1921
pub use header::Header;
2022
pub use validation::{Validation, get_current_timestamp};
@@ -31,7 +33,7 @@ mod decoding;
3133
mod encoding;
3234
/// All the errors that can be encountered while encoding/decoding JWTs
3335
pub mod errors;
34-
mod header;
36+
pub mod header;
3537
pub mod jwk;
3638
pub mod jws;
3739
#[cfg(feature = "use_pem")]

tests/header/mod.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use base64::{Engine, engine::general_purpose::STANDARD};
22
use wasm_bindgen_test::wasm_bindgen_test;
33

4-
use jsonwebtoken::Header;
4+
use jsonwebtoken::{
5+
Algorithm,
6+
header::{Alg, FromEncoded, Header},
7+
};
58

69
static CERT_CHAIN: [&str; 3] = include!("cert_chain.json");
710

@@ -38,3 +41,31 @@ fn x5c_der_invalid_chain() {
3841

3942
assert!(header.x5c_der().is_err());
4043
}
44+
45+
#[test]
46+
#[wasm_bindgen_test]
47+
fn decode_custom_header() {
48+
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize)]
49+
struct CustomHeader {
50+
alg: Algorithm,
51+
typ: String,
52+
nonstandard_header: String,
53+
}
54+
impl Alg for CustomHeader {
55+
fn alg(&self) -> &Algorithm {
56+
&self.alg
57+
}
58+
}
59+
impl FromEncoded for CustomHeader {}
60+
61+
let expected = CustomHeader {
62+
alg: Algorithm::HS256,
63+
typ: "JWT".into(),
64+
nonstandard_header: "traits are awesome".into(),
65+
};
66+
67+
let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsIm5vbnN0YW5kYXJkX2hlYWRlciI6InRyYWl0cyBhcmUgYXdlc29tZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNzU5OTY4MTQ1fQ.c2VjcmV0";
68+
69+
let header = jsonwebtoken::decode_custom_header::<CustomHeader>(token).unwrap();
70+
assert_eq!(header, expected);
71+
}

0 commit comments

Comments
 (0)