Skip to content

Commit 44a377f

Browse files
authored
[ML] Introduce generic dense embedding results classes (#137861)
- Create new GenericDenseEmbedding*Results classes which are identical to the existing DenseEmbedding*Results classes, but which use "embeddings" instead of "text_embedding" in the results JSON
- Extract common code to abstract classes
- Duplicate tests for the new DenseEmbedding*Results classes
- Move the definition of the "embedding" constant String used by classes that implement the EmbeddingResult interface into that class
- Replace hard-coded uses of "text_embedding", "text_embedding_bytes", "text_embedding_bits", "sparse_embedding" and "embedding" fields with references to constants
- Do not use TaskType to define array names for InferenceResults classes
1 parent 9426456 commit 44a377f

28 files changed

+1228
-452
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,13 @@
1313
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
1414
import org.elasticsearch.inference.InferenceResults;
1515
import org.elasticsearch.inference.InferenceServiceResults;
16-
import org.elasticsearch.inference.TaskType;
1716
import org.elasticsearch.xcontent.ToXContent;
1817
import org.elasticsearch.xcontent.XContentBuilder;
1918

2019
import java.io.IOException;
2120
import java.util.Iterator;
2221
import java.util.LinkedHashMap;
2322
import java.util.List;
24-
import java.util.Locale;
2523
import java.util.Map;
2624
import java.util.stream.Collectors;
2725

@@ -42,7 +40,7 @@
4240
public record ChatCompletionResults(List<Result> results) implements InferenceServiceResults {
4341

4442
public static final String NAME = "chat_completion_service_results";
45-
public static final String COMPLETION = TaskType.COMPLETION.name().toLowerCase(Locale.ROOT);
43+
public static final String COMPLETION = "completion";
4644

4745
public ChatCompletionResults(StreamInput in) throws IOException {
4846
this(in.readCollectionAsList(Result::new));

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java

Lines changed: 6 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -3,28 +3,17 @@
33
* or more contributor license agreements. Licensed under the Elastic License
44
* 2.0; you may not use this file except in compliance with the Elastic License
55
* 2.0.
6-
*
7-
* this file was contributed to by a generative AI
86
*/
97

108
package org.elasticsearch.xpack.core.inference.results;
119

1210
import org.elasticsearch.common.io.stream.StreamInput;
13-
import org.elasticsearch.common.io.stream.StreamOutput;
14-
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
15-
import org.elasticsearch.inference.InferenceResults;
16-
import org.elasticsearch.xcontent.ToXContent;
17-
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
1811

1912
import java.io.IOException;
20-
import java.util.Iterator;
21-
import java.util.LinkedHashMap;
2213
import java.util.List;
23-
import java.util.Map;
24-
import java.util.Objects;
2514

2615
/**
27-
* Writes a dense embedding result in the follow json format.
16+
* Writes a dense embedding result in the following json format.
2817
* <pre>
2918
* {
3019
* "text_embedding_bits": [
@@ -42,67 +31,21 @@
4231
* }
4332
* </pre>
4433
*/
45-
// Note: inheriting from DenseEmbeddingByteResults gives a bad implementation of the
46-
// Embedding.merge method for bits. TODO: implement a proper merge method
47-
public record DenseEmbeddingBitResults(List<DenseEmbeddingByteResults.Embedding> embeddings)
48-
implements
49-
DenseEmbeddingResults<DenseEmbeddingByteResults.Embedding> {
34+
public final class DenseEmbeddingBitResults extends EmbeddingBitResults {
5035
// This name is a holdover from before this class was renamed
5136
public static final String NAME = "text_embedding_service_bit_results";
5237
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
5338

54-
public DenseEmbeddingBitResults(StreamInput in) throws IOException {
55-
this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
56-
}
57-
58-
@Override
59-
public int getFirstEmbeddingSize() {
60-
if (embeddings.isEmpty()) {
61-
throw new IllegalStateException("Embeddings list is empty");
62-
}
63-
// bit embeddings are encoded as bytes so convert this to bits
64-
return Byte.SIZE * embeddings.getFirst().values().length;
39+
public DenseEmbeddingBitResults(List<EmbeddingByteResults.Embedding> embeddings) {
40+
super(embeddings, TEXT_EMBEDDING_BITS);
6541
}
6642

67-
@Override
68-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
69-
return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BITS, embeddings.iterator());
70-
}
71-
72-
@Override
73-
public void writeTo(StreamOutput out) throws IOException {
74-
out.writeCollection(embeddings);
43+
public DenseEmbeddingBitResults(StreamInput in) throws IOException {
44+
super(in, TEXT_EMBEDDING_BITS);
7545
}
7646

7747
@Override
7848
public String getWriteableName() {
7949
return NAME;
8050
}
81-
82-
@Override
83-
public List<? extends InferenceResults> transformToCoordinationFormat() {
84-
return embeddings.stream()
85-
.map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
86-
.toList();
87-
}
88-
89-
public Map<String, Object> asMap() {
90-
Map<String, Object> map = new LinkedHashMap<>();
91-
map.put(TEXT_EMBEDDING_BITS, embeddings);
92-
93-
return map;
94-
}
95-
96-
@Override
97-
public boolean equals(Object o) {
98-
if (this == o) return true;
99-
if (o == null || getClass() != o.getClass()) return false;
100-
DenseEmbeddingBitResults that = (DenseEmbeddingBitResults) o;
101-
return Objects.equals(embeddings, that.embeddings);
102-
}
103-
104-
@Override
105-
public int hashCode() {
106-
return Objects.hash(embeddings);
107-
}
10851
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java

Lines changed: 6 additions & 171 deletions
Original file line number | Diff line number | Diff line change
@@ -3,35 +3,17 @@
33
* or more contributor license agreements. Licensed under the Elastic License
44
* 2.0; you may not use this file except in compliance with the Elastic License
55
* 2.0.
6-
*
7-
* this file was contributed to by a generative AI
86
*/
97

108
package org.elasticsearch.xpack.core.inference.results;
119

12-
import org.elasticsearch.common.Strings;
13-
import org.elasticsearch.common.bytes.BytesReference;
1410
import org.elasticsearch.common.io.stream.StreamInput;
15-
import org.elasticsearch.common.io.stream.StreamOutput;
16-
import org.elasticsearch.common.io.stream.Writeable;
17-
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
18-
import org.elasticsearch.inference.InferenceResults;
19-
import org.elasticsearch.xcontent.ToXContent;
20-
import org.elasticsearch.xcontent.ToXContentObject;
21-
import org.elasticsearch.xcontent.XContent;
22-
import org.elasticsearch.xcontent.XContentBuilder;
23-
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
2411

2512
import java.io.IOException;
26-
import java.util.Arrays;
27-
import java.util.Iterator;
28-
import java.util.LinkedHashMap;
2913
import java.util.List;
30-
import java.util.Map;
31-
import java.util.Objects;
3214

3315
/**
34-
* Writes a dense embedding result in the follow json format
16+
* Writes a dense embedding result in the following json format
3517
* <pre>
3618
* {
3719
* "text_embedding_bytes": [
@@ -49,168 +31,21 @@
4931
* }
5032
* </pre>
5133
*/
52-
public record DenseEmbeddingByteResults(List<Embedding> embeddings) implements DenseEmbeddingResults<DenseEmbeddingByteResults.Embedding> {
34+
public final class DenseEmbeddingByteResults extends EmbeddingByteResults {
5335
// This name is a holdover from before this class was renamed
5436
public static final String NAME = "text_embedding_service_byte_results";
5537
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
5638

57-
public DenseEmbeddingByteResults(StreamInput in) throws IOException {
58-
this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
59-
}
60-
61-
@Override
62-
public int getFirstEmbeddingSize() {
63-
if (embeddings.isEmpty()) {
64-
throw new IllegalStateException("Embeddings list is empty");
65-
}
66-
return embeddings.getFirst().values().length;
67-
}
68-
69-
@Override
70-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
71-
return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BYTES, embeddings.iterator());
39+
public DenseEmbeddingByteResults(List<EmbeddingByteResults.Embedding> embeddings) {
40+
super(embeddings, TEXT_EMBEDDING_BYTES);
7241
}
7342

74-
@Override
75-
public void writeTo(StreamOutput out) throws IOException {
76-
out.writeCollection(embeddings);
43+
public DenseEmbeddingByteResults(StreamInput in) throws IOException {
44+
super(in, TEXT_EMBEDDING_BYTES);
7745
}
7846

7947
@Override
8048
public String getWriteableName() {
8149
return NAME;
8250
}
83-
84-
@Override
85-
public List<? extends InferenceResults> transformToCoordinationFormat() {
86-
return embeddings.stream()
87-
.map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
88-
.toList();
89-
}
90-
91-
public Map<String, Object> asMap() {
92-
Map<String, Object> map = new LinkedHashMap<>();
93-
map.put(TEXT_EMBEDDING_BYTES, embeddings);
94-
95-
return map;
96-
}
97-
98-
@Override
99-
public boolean equals(Object o) {
100-
if (this == o) return true;
101-
if (o == null || getClass() != o.getClass()) return false;
102-
DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o;
103-
return Objects.equals(embeddings, that.embeddings);
104-
}
105-
106-
@Override
107-
public int hashCode() {
108-
return Objects.hash(embeddings);
109-
}
110-
111-
// Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
112-
// embeddings should happen inbetween serializations.
113-
public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings)
114-
implements
115-
Writeable,
116-
ToXContentObject,
117-
EmbeddingResults.Embedding<Embedding> {
118-
119-
public static final String EMBEDDING = "embedding";
120-
121-
public Embedding(byte[] values) {
122-
this(values, null, 1);
123-
}
124-
125-
public Embedding(StreamInput in) throws IOException {
126-
this(in.readByteArray());
127-
}
128-
129-
@Override
130-
public void writeTo(StreamOutput out) throws IOException {
131-
out.writeByteArray(values);
132-
}
133-
134-
public static Embedding of(List<Byte> embeddingValuesList) {
135-
byte[] embeddingValues = new byte[embeddingValuesList.size()];
136-
for (int i = 0; i < embeddingValuesList.size(); i++) {
137-
embeddingValues[i] = embeddingValuesList.get(i);
138-
}
139-
return new Embedding(embeddingValues);
140-
}
141-
142-
@Override
143-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
144-
builder.startObject();
145-
146-
builder.startArray(EMBEDDING);
147-
for (byte value : values) {
148-
builder.value(value);
149-
}
150-
builder.endArray();
151-
152-
builder.endObject();
153-
return builder;
154-
}
155-
156-
@Override
157-
public String toString() {
158-
return Strings.toString(this);
159-
}
160-
161-
float[] toFloatArray() {
162-
float[] floatArray = new float[values.length];
163-
for (int i = 0; i < values.length; i++) {
164-
floatArray[i] = ((Byte) values[i]).floatValue();
165-
}
166-
return floatArray;
167-
}
168-
169-
double[] toDoubleArray() {
170-
double[] doubleArray = new double[values.length];
171-
for (int i = 0; i < values.length; i++) {
172-
doubleArray[i] = ((Byte) values[i]).doubleValue();
173-
}
174-
return doubleArray;
175-
}
176-
177-
@Override
178-
public boolean equals(Object o) {
179-
if (this == o) return true;
180-
if (o == null || getClass() != o.getClass()) return false;
181-
Embedding embedding = (Embedding) o;
182-
return Arrays.equals(values, embedding.values);
183-
}
184-
185-
@Override
186-
public int hashCode() {
187-
return Arrays.hashCode(values);
188-
}
189-
190-
@Override
191-
public Embedding merge(Embedding embedding) {
192-
byte[] newValues = new byte[values.length];
193-
int[] newSumMergedValues = new int[values.length];
194-
int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
195-
for (int i = 0; i < values.length; i++) {
196-
newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i])
197-
+ (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]);
198-
// Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
199-
// closest byte instead of truncating.
200-
newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
201-
}
202-
return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
203-
}
204-
205-
@Override
206-
public BytesReference toBytesRef(XContent xContent) throws IOException {
207-
XContentBuilder builder = XContentBuilder.builder(xContent);
208-
builder.startArray();
209-
for (byte value : values) {
210-
builder.value(value);
211-
}
212-
builder.endArray();
213-
return BytesReference.bytes(builder);
214-
}
215-
}
21651
}

0 commit comments

Comments
 (0)