Skip to content

Commit 1eeb62f

Browse files
authored
feat: sort input for SentenceTransformerEmbed to save padding cost (#1245)
1 parent 7de2f5c commit 1eeb62f

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

python/cocoindex/functions/sbert.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""SentenceTransformer embedding functionality."""
22

3-
import dataclasses
4-
from typing import Any, Literal
3+
from typing import Any, Literal, cast
54

65
import numpy as np
76
from numpy.typing import NDArray
@@ -60,7 +59,18 @@ def analyze(self) -> type:
6059

6160
def __call__(self, text: list[str]) -> list[NDArray[np.float32]]:
6261
assert self._model is not None
62+
63+
# Sort the text by length to minimize the number of padding tokens.
64+
text_with_idx = [(idx, t) for idx, t in enumerate(text)]
65+
text_with_idx.sort(key=lambda x: len(x[1]))
66+
6367
results: list[NDArray[np.float32]] = self._model.encode(
64-
text, convert_to_numpy=True
68+
[t for _, t in text_with_idx], convert_to_numpy=True
6569
)
66-
return results
70+
final_results: list[NDArray[np.float32] | None] = [
71+
None for _ in range(len(text))
72+
]
73+
for (idx, _), result in zip(text_with_idx, results):
74+
final_results[idx] = result
75+
76+
return cast(list[NDArray[np.float32]], final_results)

0 commit comments

Comments
 (0)