Skip to content

Commit 1cde7e2

Browse files
committed
Optimize MistralAiEmbeddingModel dimensions method
- Calculate and cache values for unknown models only if necessary - Make known embedding dimensions a mutable map attribute - Verify the cache mechanism with MistralAiEmbeddingModelTests Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 944029a commit 1cde7e2

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import static org.assertj.core.api.Assertions.assertThat;
3131
import static org.mockito.ArgumentMatchers.any;
32+
import static org.mockito.Mockito.verify;
3233
import static org.mockito.Mockito.when;
3334

3435
/**
@@ -78,7 +79,7 @@ void testDimensionsForCodestralEmbedModel() {
7879
void testDimensionsFallbackForUnknownModel() {
7980
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(512);
8081

81-
// Use a model name that doesn't exist in KNOWN_EMBEDDING_DIMENSIONS
82+
// Use a model name that doesn't exist in knownEmbeddingDimensions.
8283
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder().withModel("unknown-model").build();
8384

8485
MistralAiEmbeddingModel model = MistralAiEmbeddingModel.builder()
@@ -88,17 +89,23 @@ void testDimensionsFallbackForUnknownModel() {
8889
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8990
.build();
9091

91-
// Should fall back to super.dimensions() which detects dimensions from the API
92-
// response
92+
// For the first call, it should fall back to super.dimensions() which detects
93+
// dimensions from the API response.
9394
assertThat(model.dimensions()).isEqualTo(512);
95+
96+
// For the second call, it should use the cache mechanism.
97+
assertThat(model.dimensions()).isEqualTo(512);
98+
99+
// Verify that super.dimensions() has been called once.
100+
verify(mockApi).embeddings(any());
94101
}
95102

96103
@Test
97104
void testAllEmbeddingModelsHaveDimensionMapping() {
98105
// This test ensures that knownEmbeddingDimensions map stays in sync with the
99-
// EmbeddingModel enum
106+
// EmbeddingModel enum.
100107
// If a new model is added to the enum but not to the dimensions map, this test
101-
// will help catch it
108+
// will help catch it.
102109

103110
for (MistralAiApi.EmbeddingModel embeddingModel : MistralAiApi.EmbeddingModel.values()) {
104111
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(1024);

0 commit comments

Comments
 (0)