@@ -34,22 +34,24 @@ bool is_whitespace(const std::string& input) {
 }

 template <class Key_, class Value_>
-c10::Dict<Key_, Value_> _map_to_c10_dict(std::unordered_map<Key_, Value_> m) {
+c10::Dict<Key_, Value_> _map_to_c10_dict(
+    const std::unordered_map<Key_, Value_>& m) {
   c10::Dict<Key_, Value_> d;
   for (const auto& item : m)
     d.insert(item.first, item.second);
   return d;
 }

 template <class Key_, class Value_>
-std::unordered_map<Key_, Value_> _c10_dict_to_map(c10::Dict<Key_, Value_> d) {
+std::unordered_map<Key_, Value_> _c10_dict_to_map(
+    const c10::Dict<Key_, Value_>& d) {
   std::unordered_map<Key_, Value_> m;
   for (const auto& item : d)
     m[item.key()] = item.value();
   return m;
 }
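
For context, a minimal round-trip sketch of the two helpers above (my illustration, not part of the patch; it assumes c10::Dict from <ATen/core/Dict.h> is available). With the new const-reference signatures, neither call copies its argument on the way in:

    #include <ATen/core/Dict.h>
    #include <cstdint>
    #include <string>
    #include <unordered_map>

    void round_trip_example() {
      std::unordered_map<int64_t, std::string> m{{0, "zero"}, {1, "one"}};
      // Map -> c10::Dict -> map using the helpers defined in the diff above.
      c10::Dict<int64_t, std::string> d =
          _map_to_c10_dict<int64_t, std::string>(m);
      std::unordered_map<int64_t, std::string> back =
          _c10_dict_to_map<int64_t, std::string>(d);
      // back holds the same entries as m.
    }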

-std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
+std::vector<std::string> gpt2_bpe_pre_tokenizer(const std::string& input) {
   // Python implementation:
   // https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
   // Original regex contains a negative lookahead pattern, which is not
@@ -102,16 +104,16 @@ std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
 }

 std::pair<std::string, std::string> split_tokens(
-    std::string s,
-    std::string delimiter) {
+    const std::string& s,
+    const std::string& delimiter) {
   auto pos = s.find(delimiter);
   TORCH_CHECK(pos != std::string::npos, "Expected `s` to contain `delimiter`");
   return std::make_pair(s.substr(0, pos), s.substr(pos + delimiter.length()));
 }
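
A quick usage sketch of split_tokens (my example, not from the patch); "\u0001" is the same separator convention get_pairs uses further down:

    #include <string>
    #include <utility>

    void split_tokens_example() {
      std::pair<std::string, std::string> parts =
          split_tokens("he\u0001llo", "\u0001");
      // parts.first == "he", parts.second == "llo"
    }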

 int list_str_index(
-    std::vector<std::string> list,
-    std::string element,
+    const std::vector<std::string>& list,
+    const std::string& element,
     int start) {
   // Equivalent to: list.index(element, start)
   for (std::size_t i = start; i < list.size(); ++i) {
@@ -130,7 +132,7 @@ std::string concatenate_strings(const std::vector<std::string>& list) {
 }

 std::vector<std::string> get_pairs(
-    std::vector<std::string> token_list,
+    const std::vector<std::string>& token_list,
     const std::string& separator) {
   // For example: ["he", "l", "l", "o"]
   // ==> ["he\u0001l", "l\u0001l", "l\u0001o"]
@@ -175,7 +177,7 @@ GPT2BPEEncoder::GPT2BPEEncoder(
           _map_to_c10_dict<int64_t, std::string>(byte_encoder),
           caching_enabled) {}

-std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
+std::vector<std::string> GPT2BPEEncoder::ByteEncode_(const std::string& token) {
   // Equivalent to: (self.byte_encoder[b] for b in token.encode('utf-8'))
   std::vector<std::string> encoded;
   for (auto& ch : token) {
@@ -184,14 +186,15 @@ std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
   return encoded;
 }
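
For reference, GPT-2's byte encoder maps every raw byte to a printable unicode string (e.g. the space byte 0x20 becomes "Ġ", U+0120) so the merge table never contains literal whitespace. A hedged sketch of the per-byte lookup this loop performs, using a plain std::unordered_map as a stand-in for the byte_encoder_ member:

    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    std::vector<std::string> byte_encode_sketch(
        const std::string& token,
        const std::unordered_map<int64_t, std::string>& byte_encoder) {
      std::vector<std::string> encoded;
      for (const char ch : token) {
        // Cast through unsigned char so bytes >= 0x80 index correctly.
        encoded.push_back(byte_encoder.at(static_cast<unsigned char>(ch)));
      }
      return encoded;
    }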

-int64_t GPT2BPEEncoder::GetBPEMergeRank_(std::string pair) {
+int64_t GPT2BPEEncoder::GetBPEMergeRank_(const std::string& pair) {
   if (bpe_merge_ranks_.contains(pair)) {
     return bpe_merge_ranks_.at(pair);
   }
   return inf_;
 }
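
Judging from the get_pairs comment above, the keys of bpe_merge_ranks_ are two adjacent tokens joined by the separator, and a pair absent from the table ranks as inf_, so the min-by-rank search in FindBestPair_ below can never select it. A toy illustration with a plain map and made-up ranks:

    #include <cstdint>
    #include <limits>
    #include <string>
    #include <unordered_map>

    void merge_rank_sketch() {
      const int64_t inf = std::numeric_limits<int64_t>::max();  // stand-in for inf_
      std::unordered_map<std::string, int64_t> ranks{
          {"h\u0001e", 0}, {"l\u0001l", 1}};
      auto rank = [&](const std::string& pair) {
        auto it = ranks.find(pair);
        return it == ranks.end() ? inf : it->second;
      };
      // rank("h\u0001e") == 0; rank("z\u0001z") == inf, never the minimum.
      (void)rank;
    }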

-std::string GPT2BPEEncoder::FindBestPair_(std::vector<std::string> pairs) {
+std::string GPT2BPEEncoder::FindBestPair_(
+    const std::vector<std::string>& pairs) {
   // Equivalent to:
   // min(pairs, key = lambda pair: self.bpe_merge_ranks.get(pair,
   // float('inf')))
@@ -277,7 +280,8 @@ std::vector<std::string> GPT2BPEEncoder::BPE_(
   return tok_list;
 }

-std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
+std::vector<std::string> GPT2BPEEncoder::PreTokenize_(
+    const std::string& input) {
   return gpt2_bpe_pre_tokenizer(input);
 }
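
As a usage sketch (assuming the behavior matches the referenced fairseq regex), the pre-tokenizer keeps each word's leading space attached to the word:

    #include <string>
    #include <vector>

    void pre_tokenize_example() {
      std::vector<std::string> toks = gpt2_bpe_pre_tokenizer("hello world!");
      // Expected: {"hello", " world", "!"}
    }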

@@ -327,8 +331,8 @@ GPT2BPEEncoderStatesTorchbind _serialize_gpt2_bpe_encoder_torchbind(
 }

 c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
-    GPT2BPEEncoderStatesPybind states) {
-  auto state_size = std::tuple_size<decltype(states)>::value;
+    const GPT2BPEEncoderStatesPybind& states) {
+  auto state_size = std::tuple_size<GPT2BPEEncoderStatesPybind>::value;
   TORCH_CHECK(
       state_size == 5,
       "Expected deserialized GPT2BPEEncoder to have 5 states but found " +
@@ -342,8 +346,8 @@ c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
 }

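A note on the decltype change above: once states is a const reference, decltype(states) is const GPT2BPEEncoderStatesPybind&, and std::tuple_size has no specialization for reference types, so the old expression would no longer compile. Spelling out the tuple type (or stripping the reference) avoids that; a minimal sketch with a stand-in tuple alias:

    #include <string>
    #include <tuple>
    #include <type_traits>

    using States = std::tuple<int, double, std::string>;  // stand-in alias

    void tuple_size_note(const States& states) {
      // std::tuple_size<decltype(states)>::value would be ill-formed here:
      // decltype(states) is const States&, a reference type.
      static_assert(std::tuple_size<States>::value == 3, "size");
      // An equivalent fix that keeps decltype:
      static_assert(
          std::tuple_size<std::remove_reference_t<decltype(states)>>::value == 3,
          "size");
      (void)states;  // silence unused-parameter warnings
    }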
 c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_torchbind(
-    GPT2BPEEncoderStatesTorchbind states) {
-  auto state_size = std::tuple_size<decltype(states)>::value;
+    const GPT2BPEEncoderStatesTorchbind& states) {
+  auto state_size = std::tuple_size<GPT2BPEEncoderStatesTorchbind>::value;
   TORCH_CHECK(
       state_size == 5,
       "Expected deserialized GPT2BPEEncoder to have 5 states but found " +