|
3 | 3 | * or more contributor license agreements. Licensed under the Elastic License |
4 | 4 | * 2.0; you may not use this file except in compliance with the Elastic License |
5 | 5 | * 2.0. |
6 | | - * |
7 | | - * this file was contributed to by a generative AI |
8 | 6 | */ |
9 | 7 |
|
10 | 8 | package org.elasticsearch.xpack.core.inference.results; |
11 | 9 |
|
12 | | -import org.elasticsearch.common.Strings; |
13 | | -import org.elasticsearch.common.bytes.BytesReference; |
14 | 10 | import org.elasticsearch.common.io.stream.StreamInput; |
15 | | -import org.elasticsearch.common.io.stream.StreamOutput; |
16 | | -import org.elasticsearch.common.io.stream.Writeable; |
17 | | -import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; |
18 | | -import org.elasticsearch.inference.InferenceResults; |
19 | | -import org.elasticsearch.xcontent.ToXContent; |
20 | | -import org.elasticsearch.xcontent.ToXContentObject; |
21 | | -import org.elasticsearch.xcontent.XContent; |
22 | | -import org.elasticsearch.xcontent.XContentBuilder; |
23 | | -import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; |
24 | 11 |
|
25 | 12 | import java.io.IOException; |
26 | | -import java.util.Arrays; |
27 | | -import java.util.Iterator; |
28 | | -import java.util.LinkedHashMap; |
29 | 13 | import java.util.List; |
30 | | -import java.util.Map; |
31 | | -import java.util.Objects; |
32 | 14 |
|
33 | 15 | /** |
34 | | - * Writes a dense embedding result in the follow json format |
| 16 | + * Writes a dense embedding result in the following json format |
35 | 17 | * <pre> |
36 | 18 | * { |
37 | 19 | * "text_embedding_bytes": [ |
|
49 | 31 | * } |
50 | 32 | * </pre> |
51 | 33 | */ |
52 | | -public record DenseEmbeddingByteResults(List<Embedding> embeddings) implements DenseEmbeddingResults<DenseEmbeddingByteResults.Embedding> { |
| 34 | +public final class DenseEmbeddingByteResults extends EmbeddingByteResults { |
53 | 35 | // This name is a holdover from before this class was renamed |
54 | 36 | public static final String NAME = "text_embedding_service_byte_results"; |
55 | 37 | public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes"; |
56 | 38 |
|
57 | | - public DenseEmbeddingByteResults(StreamInput in) throws IOException { |
58 | | - this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new)); |
59 | | - } |
60 | | - |
61 | | - @Override |
62 | | - public int getFirstEmbeddingSize() { |
63 | | - if (embeddings.isEmpty()) { |
64 | | - throw new IllegalStateException("Embeddings list is empty"); |
65 | | - } |
66 | | - return embeddings.getFirst().values().length; |
67 | | - } |
68 | | - |
69 | | - @Override |
70 | | - public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) { |
71 | | - return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BYTES, embeddings.iterator()); |
| 39 | + public DenseEmbeddingByteResults(List<EmbeddingByteResults.Embedding> embeddings) { |
| 40 | + super(embeddings, TEXT_EMBEDDING_BYTES); |
72 | 41 | } |
73 | 42 |
|
74 | | - @Override |
75 | | - public void writeTo(StreamOutput out) throws IOException { |
76 | | - out.writeCollection(embeddings); |
| 43 | + public DenseEmbeddingByteResults(StreamInput in) throws IOException { |
| 44 | + super(in, TEXT_EMBEDDING_BYTES); |
77 | 45 | } |
78 | 46 |
|
79 | 47 | @Override |
80 | 48 | public String getWriteableName() { |
81 | 49 | return NAME; |
82 | 50 | } |
83 | | - |
84 | | - @Override |
85 | | - public List<? extends InferenceResults> transformToCoordinationFormat() { |
86 | | - return embeddings.stream() |
87 | | - .map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false)) |
88 | | - .toList(); |
89 | | - } |
90 | | - |
91 | | - public Map<String, Object> asMap() { |
92 | | - Map<String, Object> map = new LinkedHashMap<>(); |
93 | | - map.put(TEXT_EMBEDDING_BYTES, embeddings); |
94 | | - |
95 | | - return map; |
96 | | - } |
97 | | - |
98 | | - @Override |
99 | | - public boolean equals(Object o) { |
100 | | - if (this == o) return true; |
101 | | - if (o == null || getClass() != o.getClass()) return false; |
102 | | - DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o; |
103 | | - return Objects.equals(embeddings, that.embeddings); |
104 | | - } |
105 | | - |
106 | | - @Override |
107 | | - public int hashCode() { |
108 | | - return Objects.hash(embeddings); |
109 | | - } |
110 | | - |
111 | | - // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging |
112 | | - // embeddings should happen inbetween serializations. |
113 | | - public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings) |
114 | | - implements |
115 | | - Writeable, |
116 | | - ToXContentObject, |
117 | | - EmbeddingResults.Embedding<Embedding> { |
118 | | - |
119 | | - public static final String EMBEDDING = "embedding"; |
120 | | - |
121 | | - public Embedding(byte[] values) { |
122 | | - this(values, null, 1); |
123 | | - } |
124 | | - |
125 | | - public Embedding(StreamInput in) throws IOException { |
126 | | - this(in.readByteArray()); |
127 | | - } |
128 | | - |
129 | | - @Override |
130 | | - public void writeTo(StreamOutput out) throws IOException { |
131 | | - out.writeByteArray(values); |
132 | | - } |
133 | | - |
134 | | - public static Embedding of(List<Byte> embeddingValuesList) { |
135 | | - byte[] embeddingValues = new byte[embeddingValuesList.size()]; |
136 | | - for (int i = 0; i < embeddingValuesList.size(); i++) { |
137 | | - embeddingValues[i] = embeddingValuesList.get(i); |
138 | | - } |
139 | | - return new Embedding(embeddingValues); |
140 | | - } |
141 | | - |
142 | | - @Override |
143 | | - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { |
144 | | - builder.startObject(); |
145 | | - |
146 | | - builder.startArray(EMBEDDING); |
147 | | - for (byte value : values) { |
148 | | - builder.value(value); |
149 | | - } |
150 | | - builder.endArray(); |
151 | | - |
152 | | - builder.endObject(); |
153 | | - return builder; |
154 | | - } |
155 | | - |
156 | | - @Override |
157 | | - public String toString() { |
158 | | - return Strings.toString(this); |
159 | | - } |
160 | | - |
161 | | - float[] toFloatArray() { |
162 | | - float[] floatArray = new float[values.length]; |
163 | | - for (int i = 0; i < values.length; i++) { |
164 | | - floatArray[i] = ((Byte) values[i]).floatValue(); |
165 | | - } |
166 | | - return floatArray; |
167 | | - } |
168 | | - |
169 | | - double[] toDoubleArray() { |
170 | | - double[] doubleArray = new double[values.length]; |
171 | | - for (int i = 0; i < values.length; i++) { |
172 | | - doubleArray[i] = ((Byte) values[i]).doubleValue(); |
173 | | - } |
174 | | - return doubleArray; |
175 | | - } |
176 | | - |
177 | | - @Override |
178 | | - public boolean equals(Object o) { |
179 | | - if (this == o) return true; |
180 | | - if (o == null || getClass() != o.getClass()) return false; |
181 | | - Embedding embedding = (Embedding) o; |
182 | | - return Arrays.equals(values, embedding.values); |
183 | | - } |
184 | | - |
185 | | - @Override |
186 | | - public int hashCode() { |
187 | | - return Arrays.hashCode(values); |
188 | | - } |
189 | | - |
190 | | - @Override |
191 | | - public Embedding merge(Embedding embedding) { |
192 | | - byte[] newValues = new byte[values.length]; |
193 | | - int[] newSumMergedValues = new int[values.length]; |
194 | | - int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings; |
195 | | - for (int i = 0; i < values.length; i++) { |
196 | | - newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i]) |
197 | | - + (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]); |
198 | | - // Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the |
199 | | - // closest byte instead of truncating. |
200 | | - newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings); |
201 | | - } |
202 | | - return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings); |
203 | | - } |
204 | | - |
205 | | - @Override |
206 | | - public BytesReference toBytesRef(XContent xContent) throws IOException { |
207 | | - XContentBuilder builder = XContentBuilder.builder(xContent); |
208 | | - builder.startArray(); |
209 | | - for (byte value : values) { |
210 | | - builder.value(value); |
211 | | - } |
212 | | - builder.endArray(); |
213 | | - return BytesReference.bytes(builder); |
214 | | - } |
215 | | - } |
216 | 51 | } |
0 commit comments