This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit e76aa87

laithsakka authored and facebook-github-bot committed
Fix performance issues in gpt2_bpe_tokenizer (#401)
Summary:
Pull Request resolved: #401

Complex structures in C++ should be passed by const reference rather than by value to avoid copying data. A number of functions in gpt2_bpe_tokenizer were taking their arguments by value.

Reviewed By: kevinwilfong, wenleix, Nayef211

Differential Revision: D37423480

fbshipit-source-id: 9b0dd5bf428214204d5d0322a08bd8506822382c
1 parent 0c54a46 commit e76aa87
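Not part of the commit itself, but a minimal sketch of the pattern the summary describes: passing a heavy container by value copies it on every call, while a const reference lets the callee read the caller's object without any copy.

#include <cstddef>
#include <string>
#include <vector>

// Pass by value: the vector, and every string inside it, is copied on each call.
std::size_t count_by_value(std::vector<std::string> tokens) {
  return tokens.size();
}

// Pass by const reference: no copy; the callee can read but not modify the argument.
std::size_t count_by_const_ref(const std::vector<std::string>& tokens) {
  return tokens.size();
}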

File tree

csrc/velox/functions/text/gpt2_bpe_tokenizer.cpp
csrc/velox/functions/text/gpt2_bpe_tokenizer.h

2 files changed: +32 −28 lines changed


csrc/velox/functions/text/gpt2_bpe_tokenizer.cpp

Lines changed: 20 additions & 16 deletions
@@ -34,22 +34,24 @@ bool is_whitespace(const std::string& input) {
}

template <class Key_, class Value_>
-c10::Dict<Key_, Value_> _map_to_c10_dict(std::unordered_map<Key_, Value_> m) {
+c10::Dict<Key_, Value_> _map_to_c10_dict(
+    const std::unordered_map<Key_, Value_>& m) {
  c10::Dict<Key_, Value_> d;
  for (const auto& item : m)
    d.insert(item.first, item.second);
  return d;
}

template <class Key_, class Value_>
-std::unordered_map<Key_, Value_> _c10_dict_to_map(c10::Dict<Key_, Value_> d) {
+std::unordered_map<Key_, Value_> _c10_dict_to_map(
+    const c10::Dict<Key_, Value_>& d) {
  std::unordered_map<Key_, Value_> m;
  for (const auto& item : d)
    m[item.key()] = item.value();
  return m;
}

-std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
+std::vector<std::string> gpt2_bpe_pre_tokenizer(const std::string& input) {
  // Python implementation:
  // https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
  // Original regex contains a negative lookahead pattern, which is not

@@ -102,16 +104,16 @@ std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
}

std::pair<std::string, std::string> split_tokens(
-    std::string s,
-    std::string delimiter) {
+    const std::string& s,
+    const std::string& delimiter) {
  auto pos = s.find(delimiter);
  TORCH_CHECK(pos != std::string::npos, "Expected `s`to contain `delimiter`");
  return std::make_pair(s.substr(0, pos), s.substr(pos + delimiter.length()));
}

int list_str_index(
-    std::vector<std::string> list,
-    std::string element,
+    const std::vector<std::string>& list,
+    const std::string& element,
    int start) {
  // Equivalent to: list.index(element, start)
  for (std::size_t i = start; i < list.size(); ++i) {

@@ -130,7 +132,7 @@ std::string concatenate_strings(const std::vector<std::string>& list) {
}

std::vector<std::string> get_pairs(
-    std::vector<std::string> token_list,
+    const std::vector<std::string>& token_list,
    const std::string& separator) {
  // For example: ["he", "l", "l", "o"]
  // ==> ["he\u0001l", "l\u0001l", "l\u0001o"]

@@ -175,7 +177,7 @@ GPT2BPEEncoder::GPT2BPEEncoder(
          _map_to_c10_dict<int64_t, std::string>(byte_encoder),
          caching_enabled) {}

-std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
+std::vector<std::string> GPT2BPEEncoder::ByteEncode_(const std::string& token) {
  // Equivalent to: (self.byte_encoder[b] for b in token.encode('utf-8')
  std::vector<std::string> encoded;
  for (auto& ch : token) {

@@ -184,14 +186,15 @@ std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
  return encoded;
}

-int64_t GPT2BPEEncoder::GetBPEMergeRank_(std::string pair) {
+int64_t GPT2BPEEncoder::GetBPEMergeRank_(const std::string& pair) {
  if (bpe_merge_ranks_.contains(pair)) {
    return bpe_merge_ranks_.at(pair);
  }
  return inf_;
}

-std::string GPT2BPEEncoder::FindBestPair_(std::vector<std::string> pairs) {
+std::string GPT2BPEEncoder::FindBestPair_(
+    const std::vector<std::string>& pairs) {
  // Equivalent to:
  // min(pairs, key = lambda pair: self.bpe_merge_ranks.get(pair,
  // float('inf')))

@@ -277,7 +280,8 @@ std::vector<std::string> GPT2BPEEncoder::BPE_(
  return tok_list;
}

-std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
+std::vector<std::string> GPT2BPEEncoder::PreTokenize_(
+    const std::string& input) {
  return gpt2_bpe_pre_tokenizer(input);
}

@@ -327,8 +331,8 @@ GPT2BPEEncoderStatesTorchbind _serialize_gpt2_bpe_encoder_torchbind(
}

c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
-    GPT2BPEEncoderStatesPybind states) {
-  auto state_size = std::tuple_size<decltype(states)>::value;
+    const GPT2BPEEncoderStatesPybind& states) {
+  auto state_size = std::tuple_size<GPT2BPEEncoderStatesPybind>::value;
  TORCH_CHECK(
      state_size == 5,
      "Expected deserialized GPT2BPEEncoder to have 5 states but found " +

@@ -342,8 +346,8 @@ c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
}

c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_torchbind(
-    GPT2BPEEncoderStatesTorchbind states) {
-  auto state_size = std::tuple_size<decltype(states)>::value;
+    const GPT2BPEEncoderStatesTorchbind& states) {
+  auto state_size = std::tuple_size<GPT2BPEEncoderStatesTorchbind>::value;
  TORCH_CHECK(
      state_size == 5,
      "Expected deserialized GPT2BPEEncoder to have 5 states but found " +

csrc/velox/functions/text/gpt2_bpe_tokenizer.h

Lines changed: 12 additions & 12 deletions
@@ -42,42 +42,42 @@ typedef std::tuple<

// Applies regex based pre-tokenization step for GPT-2 BPE tokenizer
// and returns a list of tokens.
-std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input);
+std::vector<std::string> gpt2_bpe_pre_tokenizer(const std::string& input);

// Concatenate a vector of strings to a single string
std::string concatenate_strings(const std::vector<std::string>& list);

// Return set of token pairs in a word, separated by the `separator`.
std::vector<std::string> get_pairs(
-    std::vector<std::string> token_list,
+    const std::vector<std::string>& token_list,
    const std::string& separator);

// Split a string into 2 parts separated by a `separator`.
std::pair<std::string, std::string> split_tokens(
-    std::string s,
-    std::string delimiter);
+    const std::string& s,
+    const std::string& delimiter);

// Find index of `element` in a list of strings.
int list_str_index(
-    std::vector<std::string> list,
-    std::string element,
+    const std::vector<std::string>& list,
+    const std::string& element,
    int start);

struct GPT2BPEEncoder : torch::CustomClassHolder {
 private:
  const int64_t inf_;
  // Encode byte into an unicode character.
-  std::vector<std::string> ByteEncode_(std::string token);
-  int64_t GetBPEMergeRank_(std::string pair);
+  std::vector<std::string> ByteEncode_(const std::string& token);
+  int64_t GetBPEMergeRank_(const std::string& pair);

 protected:
  c10::Dict<std::string, std::vector<std::string>> cache_;
-  virtual std::vector<std::string> PreTokenize_(std::string input);
+  virtual std::vector<std::string> PreTokenize_(const std::string& input);
  // Return a list of bpe tokens.
  virtual std::vector<std::string> BPE_(
      const std::vector<std::string>& token_list);
  // Return the token pair(e.g bpe merge) with lowest rank.
-  std::string FindBestPair_(std::vector<std::string> pairs);
+  std::string FindBestPair_(const std::vector<std::string>& pairs);

 public:
  const c10::Dict<std::string, int64_t> bpe_encoder_;

@@ -122,9 +122,9 @@ GPT2BPEEncoderStatesPybind _serialize_gpt2_bpe_encoder_pybind(
GPT2BPEEncoderStatesTorchbind _serialize_gpt2_bpe_encoder_torchbind(
    const c10::intrusive_ptr<GPT2BPEEncoder>& self);
c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
-    GPT2BPEEncoderStatesPybind states);
+    const GPT2BPEEncoderStatesPybind& states);
c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_torchbind(
-    GPT2BPEEncoderStatesTorchbind states);
+    const GPT2BPEEncoderStatesTorchbind& states);
} // namespace facebook::torcharrow::functions

#endif // GPT2_BPE_TOKENIZER_H_
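Because only the parameter qualifiers changed, existing call sites keep working: temporaries and named values both bind to a const reference. A hypothetical snippet (not part of this commit, with an assumed include path) against the declarations above:

#include <string>
#include <utility>
#include <vector>

#include "gpt2_bpe_tokenizer.h"  // assumed include path

void example() {
  using namespace facebook::torcharrow::functions;
  // A temporary std::string binds to `const std::string&` just as it did to the old by-value parameter.
  std::vector<std::string> tokens = gpt2_bpe_pre_tokenizer("hello world");
  // split_tokens splits on the first occurrence of the delimiter: {"hello", "world"}.
  std::pair<std::string, std::string> parts = split_tokens("hello world", " ");
}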
