
Commit d6a4acc

Update serialization (#1891)
* Add benchmark for deserializing large added vocab
* revert dumb stuff, isolate changes
* try to only normalize once
* small improvement?
* some updates
* nit
* fmt
* normalized string are a fucking waste of time when you just want to add tokens to the vocab man....
* more attempts
* works
* let's fucking go, parity
* update
* hahahhahaha
* revert changes that are not actually even needed
* add a python test!
* use normalizer before come on
* nit
* update to a more concrete usecase
* fix build
* style
* reduce sample size
* --allow unmaintained
* clippy happy
* up
* up
* derive impl
* revert unrelated
* fmt
* ignore
* remove stupid file
1 parent 47e4ffe commit d6a4acc

6 files changed: +117 -25 lines


.github/workflows/python.yml

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ jobs:
         uses: actions-rs/cargo@v1
         with:
           command: audit
-          args: -D warnings -f ./bindings/python/Cargo.lock --ignore RUSTSEC-2024-0436 --ignore RUSTSEC-2025-0014
+          args: -D warnings -f ./bindings/python/Cargo.lock --ignore RUSTSEC-2024-0436 --ignore RUSTSEC-2025-0014 --ignore RUSTSEC-2025-0119 --ignore RUSTSEC-2024-0436
 
       - name: Install
         working-directory: ./bindings/python

.github/workflows/rust.yml

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ jobs:
         uses: actions-rs/cargo@v1
         with:
           command: audit
-          args: -D warnings -f ./tokenizers/Cargo.lock --ignore RUSTSEC-2024-0436 --ignore RUSTSEC-2025-0014
+          args: -D warnings -f ./tokenizers/Cargo.lock --ignore RUSTSEC-2024-0436 --ignore RUSTSEC-2025-0014 --ignore RUSTSEC-2025-0119
 
       # Verify that Readme.md is up to date.
       - name: Make sure, Readme generated from lib.rs matches actual Readme
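
In both workflows the cargo audit step now additionally ignores RUSTSEC-2025-0119. Note that the python.yml args also end up listing --ignore RUSTSEC-2024-0436 twice; repeating the same ignore flag is presumably harmless, just redundant.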

tokenizers/Cargo.toml

Lines changed: 5 additions & 0 deletions
@@ -40,6 +40,11 @@ harness = false
 name = "llama3_benchmark"
 harness = false
 
+[[bench]]
+name = "added_vocab_deserialize"
+required-features = ["http"]
+harness = false
+
 [dependencies]
 rand = "0.9"
 onig = { version = "6.5.1", default-features = false, optional = true }
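
A note on the new [[bench]] target: required-features = ["http"] means cargo only builds it when the http feature is enabled, presumably because the benchmark below fetches "t5-small" through Tokenizer::from_pretrained. Running it locally should therefore look something like cargo bench --bench added_vocab_deserialize --features http.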

tokenizers/benches/added_vocab_deserialize.rs

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
#[macro_use]
extern crate criterion;
use criterion::Criterion;
use std::hint::black_box;
use std::str::FromStr;
use tokenizers::{normalizers::*, AddedToken, Normalizer, Tokenizer};

fn serialized_tokenizer<N: Normalizer + Into<NormalizerWrapper>>(
    size: i64,
    normalizer: Option<N>,
    special_tokens: bool,
) -> String {
    let mut tokenizer = Tokenizer::from_pretrained("t5-small", None).unwrap();

    if let Some(norm) = normalizer {
        tokenizer.with_normalizer(Some(norm));
    }

    let tokens: Vec<_> = (0..size)
        .map(|i| AddedToken::from(format!("tok{i}"), special_tokens))
        .collect();
    tokenizer.add_tokens(&tokens);

    serde_json::to_string(&tokenizer).unwrap()
}

#[allow(clippy::type_complexity)]
fn bench_deserialize(c: &mut Criterion) {
    let normalizers: Vec<(&str, Option<fn() -> NormalizerWrapper>)> = vec![
        ("none", None),
        ("byte_level", Some(|| ByteLevel.into())),
        ("lowercase", Some(|| Lowercase.into())),
        ("nfc", Some(|| NFC.into())),
        ("nfd", Some(|| NFD.into())),
        ("nfkc", Some(|| NFKC.into())),
        ("nfkd", Some(|| NFKD.into())),
        ("nmt", Some(|| Nmt.into())),
        ("strip", Some(|| Strip::new(true, true).into())),
        ("replace", Some(|| Replace::new("a", "b").unwrap().into())),
        ("prepend", Some(|| Prepend::new("pre_".to_string()).into())),
        ("bert", Some(|| BertNormalizer::default().into())),
    ];

    for &size in &[100_000, 400_000] {
        for (norm_name, maybe_factory) in &normalizers {
            let label = format!(
                "special tokens deserialize_added_vocab_{}_norm_{}",
                size, norm_name
            );

            let json = match maybe_factory {
                Some(factory) => serialized_tokenizer(size, Some(factory()), true),
                None => serialized_tokenizer::<NormalizerWrapper>(size, None, true),
            };
            c.bench_function(&label, |b| {
                b.iter(|| {
                    let tok: Tokenizer = black_box(Tokenizer::from_str(&json).unwrap());
                    black_box(tok);
                })
            });

            let label = format!(
                "non special deserialize_added_vocab_{}_norm_{}",
                size, norm_name
            );

            let json = match maybe_factory {
                Some(factory) => serialized_tokenizer(size, Some(factory()), false),
                None => serialized_tokenizer::<NormalizerWrapper>(size, None, false),
            };
            c.bench_function(&label, |b| {
                b.iter(|| {
                    let tok: Tokenizer = black_box(Tokenizer::from_str(&json).unwrap());
                    black_box(tok);
                })
            });
        }
    }
}

criterion_group! {
    name = benches;
    config = Criterion::default().significance_level(0.1).sample_size(10);
    targets = bench_deserialize
}
criterion_main!(benches);
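
The Criterion configuration at the bottom (sample_size(10) with a 0.1 significance level) corresponds to the "reduce sample size" step in the commit message: every sample deserializes a tokenizer carrying 100,000 or 400,000 added tokens, so Criterion's default sample count would presumably make the run impractically slow.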

tokenizers/src/tokenizer/added_vocabulary.rs

Lines changed: 22 additions & 16 deletions
@@ -272,30 +272,35 @@ impl AddedVocabulary {
             }
         }
 
-        // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
         let mut ignored = 0;
+
+        let mut existing: AHashSet<AddedToken> =
+            self.added_tokens_map_r.values().cloned().collect();
+        let mut next_id = self.added_tokens_map_r.keys().copied().max().map_or(
+            model.get_vocab_size() as u32,
+            |max| {
+                if max >= model.get_vocab_size() as u32 || model.get_vocab_size() == 0 {
+                    max + 1
+                } else {
+                    model.get_vocab_size() as u32
+                }
+            },
+        );
+
         for token in tokens {
-            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
-            {
+            if token.content.is_empty() || existing.contains(token) {
                 ignored += 1;
                 continue;
             }
-            // If a token is already part of the vocabulary, we mark it as added
+
             let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
                 new_id
             } else {
-                self.added_tokens_map.values().cloned().max().map_or(
-                    model.get_vocab_size() as u32,
-                    |max| {
-                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
-                            max + 1
-                        } else {
-                            model.get_vocab_size() as u32
-                        }
-                    },
-                )
+                let id = next_id;
+                next_id += 1;
+                id
             };
-            // Make sure we modify the previous entry
+
             *self
                 .added_tokens_map
                 .entry(token.content.clone())

@@ -308,6 +313,7 @@ impl AddedVocabulary {
             if !self.special_tokens_set.contains(&token.content) {
                 self.added_tokens.push(token.clone());
             }
+            existing.insert(token.clone());
         }
 
         self.refresh_added_tokens(model, normalizer);

@@ -317,7 +323,7 @@ impl AddedVocabulary {
     }
 
     /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
-    ///
+    /// # TODO @ArthurZucker we should probably make this async? rebuilding the regex takes a long time.
     /// We keep two different RegexSet, one that will take care of matching against the
     /// non-normalized string, and one matching against the normalized one.
     fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
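
The reworked loop above is essentially a complexity fix: the duplicate check moves from a linear scan over added_tokens_map_r to a prebuilt AHashSet, and the next id comes from a running counter instead of recomputing the maximum of the map for every token. Below is a minimal standalone sketch of that pattern, using plain std types and hypothetical helper names rather than the crate's code, to show why it matters once hundreds of thousands of tokens are added:

    use std::collections::HashSet;

    // Old shape: every token rescans all existing entries (the previous
    // `values().any(...)` and `values().max()` calls), O(n^2) overall.
    fn add_all_rescan(existing: &mut Vec<String>, new: &[String]) {
        for t in new {
            if !existing.iter().any(|e| e == t) {
                existing.push(t.clone());
            }
        }
    }

    // New shape: one hash lookup per token plus a running id counter,
    // mirroring the `existing: AHashSet` and `next_id` variables above, O(n) overall.
    fn add_all_hashed(existing: &mut HashSet<String>, new: &[String], next_id: &mut u32) {
        for t in new {
            if existing.insert(t.clone()) {
                *next_id += 1;
            }
        }
    }

    fn main() {
        let new: Vec<String> = (0..100_000).map(|i| format!("tok{i}")).collect();

        let mut vec_existing = Vec::new();
        add_all_rescan(&mut vec_existing, &new[..1_000]); // quadratic, so keep the input small

        let mut set_existing = HashSet::new();
        let mut next_id = 0u32;
        add_all_hashed(&mut set_existing, &new, &mut next_id); // linear, handles the full list easily
        assert_eq!(next_id, 100_000);
    }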

tokenizers/src/utils/truncation.rs

Lines changed: 2 additions & 7 deletions
@@ -49,19 +49,14 @@ pub enum TruncationError {
     SequenceTooShort,
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
 pub enum TruncationStrategy {
+    #[default]
     LongestFirst,
     OnlyFirst,
     OnlySecond,
 }
 
-impl Default for TruncationStrategy {
-    fn default() -> Self {
-        Self::LongestFirst
-    }
-}
-
 impl std::convert::AsRef<str> for TruncationStrategy {
     fn as_ref(&self) -> &str {
         match self {
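
The truncation change is a small cleanup rather than a behavior change: since Rust 1.62, #[derive(Default)] works on enums when exactly one unit variant carries #[default], which makes the hand-written impl Default redundant. A minimal standalone illustration of the pattern (a hypothetical Strategy enum, not the crate's type):

    // `#[default]` selects the variant returned by `Default::default()`.
    #[derive(Debug, Default, PartialEq)]
    enum Strategy {
        #[default]
        LongestFirst,
        OnlyFirst,
        OnlySecond,
    }

    fn main() {
        assert_eq!(Strategy::default(), Strategy::LongestFirst);
    }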
