Skip to content

Commit a0f8ec1

Browse files
committed
Optimize MistralAiEmbeddingModel dimensions method
- Calculate and cache values for unknown models only if necessary - Make known embedding dimensions a mutable map attribute - Verify the cache mechanism with MistralAiEmbeddingModelTests Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 47e4232 commit a0f8ec1

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.HashMap;
1920
import java.util.List;
2021
import java.util.Map;
2122

@@ -56,16 +57,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
5657

5758
private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class);
5859

60+
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
61+
5962
/**
6063
* Known embedding dimensions for Mistral AI models. Maps model names to their
6164
* respective embedding vector dimensions. This allows the dimensions() method to
6265
* return the correct value without making an API call.
6366
*/
64-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Map.of(
65-
MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024, MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(),
66-
1536);
67-
68-
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
67+
private final Map<String, Integer> knownEmbeddingDimensions = createKnownEmbeddingDimensions();
6968

7069
private final MistralAiEmbeddingOptions defaultOptions;
7170

@@ -85,6 +84,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
8584
*/
8685
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
8786

87+
private static Map<String, Integer> createKnownEmbeddingDimensions() {
88+
Map<String, Integer> knownEmbeddingDimensions = new HashMap<>();
89+
knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024);
90+
knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(), 1536);
91+
92+
return knownEmbeddingDimensions;
93+
}
94+
8895
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode,
8996
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
9097
Assert.notNull(mistralAiApi, "mistralAiApi must not be null");
@@ -174,7 +181,8 @@ public float[] embed(Document document) {
174181

175182
@Override
176183
public int dimensions() {
177-
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
184+
return this.knownEmbeddingDimensions.computeIfAbsent(this.defaultOptions.getModel(),
185+
model -> super.dimensions());
178186
}
179187

180188
/**

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.Arrays;
1920
import java.util.List;
2021

2122
import org.junit.jupiter.api.Test;
@@ -28,6 +29,7 @@
2829

2930
import static org.assertj.core.api.Assertions.assertThat;
3031
import static org.mockito.ArgumentMatchers.any;
32+
import static org.mockito.Mockito.verify;
3133
import static org.mockito.Mockito.when;
3234

3335
/**
@@ -77,7 +79,7 @@ void testDimensionsForCodestralEmbedModel() {
7779
void testDimensionsFallbackForUnknownModel() {
7880
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(512);
7981

80-
// Use a model name that doesn't exist in KNOWN_EMBEDDING_DIMENSIONS
82+
// Use a model name that doesn't exist in knownEmbeddingDimensions.
8183
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder().withModel("unknown-model").build();
8284

8385
MistralAiEmbeddingModel model = MistralAiEmbeddingModel.builder()
@@ -87,17 +89,23 @@ void testDimensionsFallbackForUnknownModel() {
8789
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8890
.build();
8991

90-
// Should fall back to super.dimensions() which detects dimensions from the API
91-
// response
92+
// For the first call, it should fall back to super.dimensions() which detects
93+
// dimensions from the API response.
9294
assertThat(model.dimensions()).isEqualTo(512);
95+
96+
// For the second call, it should use the cache mechanism.
97+
assertThat(model.dimensions()).isEqualTo(512);
98+
99+
// Verify that super.dimensions() has been called once.
100+
verify(mockApi).embeddings(any());
93101
}
94102

95103
@Test
96104
void testAllEmbeddingModelsHaveDimensionMapping() {
97-
// This test ensures that KNOWN_EMBEDDING_DIMENSIONS map stays in sync with the
98-
// EmbeddingModel enum
105+
// This test ensures that knownEmbeddingDimensions map stays in sync with the
106+
// EmbeddingModel enum.
99107
// If a new model is added to the enum but not to the dimensions map, this test
100-
// will help catch it
108+
// will help catch it.
101109

102110
for (MistralAiApi.EmbeddingModel embeddingModel : MistralAiApi.EmbeddingModel.values()) {
103111
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(1024);
@@ -138,16 +146,13 @@ private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) {
138146

139147
// Create a mock embedding response with the specified dimensions
140148
float[] embedding = new float[dimensions];
141-
for (int i = 0; i < dimensions; i++) {
142-
embedding[i] = 0.1f;
143-
}
149+
Arrays.fill(embedding, 0.1f);
144150

145151
MistralAiApi.Embedding embeddingData = new MistralAiApi.Embedding(0, embedding, "embedding");
146152

147153
MistralAiApi.Usage usage = new MistralAiApi.Usage(10, 0, 10);
148154

149-
MistralAiApi.EmbeddingList embeddingList = new MistralAiApi.EmbeddingList("object", List.of(embeddingData),
150-
"model", usage);
155+
var embeddingList = new MistralAiApi.EmbeddingList<>("object", List.of(embeddingData), "model", usage);
151156

152157
when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingList));
153158

0 commit comments

Comments
 (0)