Skip to content

Commit f64f5cd

Browse files
CraigacpJaswanth51
authored andcommitted
[java] Auto EP and compile model support (microsoft#25131)
### Description Java API for compile model and EP discovery APIs. Roughly equivalent to the C# version in microsoft#24604. cc: @skottmckay. I haven't quite got the CMake configured so the Java tests for the ep registration only run when the ONNX Runtime shared provider support is built, but everything else works. I expect that to be a quick fix, but I'm not sure in what conditions it should be built and how we should handle it so I don't know where/when to plumb it through. ### Motivation and Context API parity for Java.
1 parent f158141 commit f64f5cd

22 files changed

+1484
-33
lines changed

cmake/onnxruntime_java.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ if (WIN32)
159159
if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS)
160160
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime>)
161161
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime4j_jni> ${JAVA_PACKAGE_JNI_DIR}/$<TARGET_FILE_NAME:onnxruntime4j_jni>)
162-
if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB))
162+
if (TARGET onnxruntime_providers_shared)
163163
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime_providers_shared>)
164164
endif()
165165
if (onnxruntime_USE_CUDA)
@@ -207,7 +207,7 @@ if (WIN32)
207207
else()
208208
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime>)
209209
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime4j_jni> ${JAVA_PACKAGE_JNI_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime4j_jni>)
210-
if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB))
210+
if (TARGET onnxruntime_providers_shared)
211211
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime_providers_shared>)
212212
endif()
213213
if (onnxruntime_USE_CUDA)

cmake/onnxruntime_unittests.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,10 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
16401640
add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD
16411641
COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR})
16421642
endif()
1643+
if (WIN32)
1644+
set(EXAMPLE_PLUGIN_EP_DST_FILE_NAME $<IF:$<BOOL:${WIN32}>,$<TARGET_FILE_NAME:example_plugin_ep>,$<TARGET_LINKER_FILE_NAME:example_plugin_ep>>)
1645+
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:example_plugin_ep> ${JAVA_NATIVE_TEST_DIR}/${EXAMPLE_PLUGIN_EP_DST_FILE_NAME})
1646+
endif()
16431647

16441648
# delegate to gradle's test runner
16451649

java/src/main/java/ai/onnxruntime/OnnxRuntime.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ final class OnnxRuntime {
4242
private static final int ORT_API_VERSION_13 = 13;
4343
// Post 1.13 builds of the ORT API
4444
private static final int ORT_API_VERSION_14 = 14;
45+
// Post 1.22 builds of the ORT API
46+
private static final int ORT_API_VERSION_23 = 23;
4547

4648
// The initial release of the ORT training API.
4749
private static final int ORT_TRAINING_API_VERSION_1 = 1;
@@ -103,6 +105,9 @@ final class OnnxRuntime {
103105
/** The Training API handle. */
104106
static long ortTrainingApiHandle;
105107

108+
/** The Compile API handle. */
109+
static long ortCompileApiHandle;
110+
106111
/** Is training enabled in the native library */
107112
static boolean trainingEnabled;
108113

@@ -176,12 +181,13 @@ static synchronized void init() throws IOException {
176181
}
177182
load(ONNXRUNTIME_JNI_LIBRARY_NAME);
178183

179-
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14);
184+
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23);
180185
if (ortApiHandle == 0L) {
181186
throw new IllegalStateException(
182187
"There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded");
183188
}
184-
ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_14);
189+
ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23);
190+
ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle);
185191
trainingEnabled = ortTrainingApiHandle != 0L;
186192
providers = initialiseProviders(ortApiHandle);
187193
version = initialiseVersion();
@@ -499,6 +505,14 @@ private static EnumSet<OrtProvider> initialiseProviders(long ortApiHandle) {
499505
*/
500506
private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber);
501507

508+
/**
509+
* Get a reference to the compile API struct.
510+
*
511+
* @param apiHandle The ORT API struct pointer.
512+
* @return A pointer to the compile API struct.
513+
*/
514+
private static native long initialiseCompileAPIBase(long apiHandle);
515+
502516
/**
503517
* Gets the array of available providers.
504518
*

java/src/main/java/ai/onnxruntime/OrtEnvironment.java

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
55
package ai.onnxruntime;
@@ -8,7 +8,11 @@
88
import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState;
99
import java.io.IOException;
1010
import java.nio.ByteBuffer;
11+
import java.util.ArrayList;
12+
import java.util.Collections;
1113
import java.util.EnumSet;
14+
import java.util.List;
15+
import java.util.Map;
1216
import java.util.Objects;
1317
import java.util.logging.Logger;
1418

@@ -442,6 +446,48 @@ public static EnumSet<OrtProvider> getAvailableProviders() {
442446
return OnnxRuntime.providers.clone();
443447
}
444448

449+
/**
450+
* Registers an execution provider library with this OrtEnvironment.
451+
*
452+
* @param registrationName The name to register the library with (used to remove it later with
453+
* {@link #unregisterExecutionProviderLibrary(String)}).
454+
* @param libraryPath The path to the library binary on disk.
455+
* @throws OrtException If the library could not be registered.
456+
*/
457+
public void registerExecutionProviderLibrary(String registrationName, String libraryPath)
458+
throws OrtException {
459+
registerExecutionProviderLibrary(
460+
OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath);
461+
}
462+
463+
/**
464+
* Unregisters an execution provider library from this OrtEnvironment.
465+
*
466+
* @param registrationName The name the library was registered under.
467+
* @throws OrtException If the library could not be removed.
468+
*/
469+
public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException {
470+
unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName);
471+
}
472+
473+
/**
474+
* Get the list of all execution provider and device combinations that are available.
475+
*
476+
* @see OrtSession.SessionOptions#addExecutionProvider(List, Map)
477+
* @return The list of execution provider and device combinations.
478+
* @throws OrtException If the devices could not be listed.
479+
*/
480+
public List<OrtEpDevice> getEpDevices() throws OrtException {
481+
long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle);
482+
483+
List<OrtEpDevice> devicesList = new ArrayList<>();
484+
for (long deviceHandle : deviceHandles) {
485+
devicesList.add(new OrtEpDevice(deviceHandle));
486+
}
487+
488+
return Collections.unmodifiableList(devicesList);
489+
}
490+
445491
/**
446492
* Creates the native object.
447493
*
@@ -476,6 +522,40 @@ private static native long createHandle(
476522
*/
477523
private static native long getDefaultAllocator(long apiHandle) throws OrtException;
478524

525+
/**
526+
* Registers the specified execution provider with this OrtEnvironment.
527+
*
528+
* @param apiHandle The API handle.
529+
* @param nativeHandle The OrtEnvironment handle.
530+
* @param registrationName The name of the execution provider.
531+
* @param libraryPath The path to the execution provider binary.
532+
* @throws OrtException If the registration failed.
533+
*/
534+
private static native void registerExecutionProviderLibrary(
535+
long apiHandle, long nativeHandle, String registrationName, String libraryPath)
536+
throws OrtException;
537+
538+
/**
539+
* Removes the specified execution provider from this OrtEnvironment.
540+
*
541+
* @param apiHandle The API handle.
542+
* @param nativeHandle The OrtEnvironment handle.
543+
* @param registrationName The name of the execution provider.
544+
* @throws OrtException If the removal failed.
545+
*/
546+
private static native void unregisterExecutionProviderLibrary(
547+
long apiHandle, long nativeHandle, String registrationName) throws OrtException;
548+
549+
/**
550+
* Gets handles for the EP device tuples available in this OrtEnvironment.
551+
*
552+
* @param apiHandle The API handle to use.
553+
* @param nativeHandle The OrtEnvironment handle.
554+
* @return An array of OrtEpDevice handles.
555+
* @throws OrtException If the call failed.
556+
*/
557+
private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException;
558+
479559
/**
480560
* Closes the OrtEnvironment, frees the handle.
481561
*
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3+
* Licensed under the MIT License.
4+
*/
5+
package ai.onnxruntime;
6+
7+
import java.util.Map;
8+
9+
/** A tuple of Execution Provider information and the hardware device. */
10+
public final class OrtEpDevice {
11+
12+
private final long nativeHandle;
13+
14+
private final String epName;
15+
private final String epVendor;
16+
private final Map<String, String> epMetadata;
17+
private final Map<String, String> epOptions;
18+
private final OrtHardwareDevice device;
19+
20+
/**
21+
* Construct an OrtEpDevice tuple from the native pointer.
22+
*
23+
* @param nativeHandle The native pointer.
24+
*/
25+
OrtEpDevice(long nativeHandle) {
26+
this.nativeHandle = nativeHandle;
27+
this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle);
28+
this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle);
29+
String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle);
30+
this.epMetadata = OrtUtil.convertToMap(metadata);
31+
String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle);
32+
this.epOptions = OrtUtil.convertToMap(options);
33+
this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle));
34+
}
35+
36+
/**
37+
* Return the native pointer.
38+
*
39+
* @return The native pointer.
40+
*/
41+
long getNativeHandle() {
42+
return nativeHandle;
43+
}
44+
45+
/**
46+
* Gets the EP name.
47+
*
48+
* @return The EP name.
49+
*/
50+
public String getName() {
51+
return epName;
52+
}
53+
54+
/**
55+
* Gets the vendor name.
56+
*
57+
* @return The vendor name.
58+
*/
59+
public String getVendor() {
60+
return epVendor;
61+
}
62+
63+
/**
64+
* Gets an unmodifiable view on the EP metadata.
65+
*
66+
* @return The EP metadata.
67+
*/
68+
public Map<String, String> getMetadata() {
69+
return epMetadata;
70+
}
71+
72+
/**
73+
* Gets an unmodifiable view on the EP options.
74+
*
75+
* @return The EP options.
76+
*/
77+
public Map<String, String> getOptions() {
78+
return epOptions;
79+
}
80+
81+
/**
82+
* Gets the device information.
83+
*
84+
* @return The device information.
85+
*/
86+
public OrtHardwareDevice getDevice() {
87+
return device;
88+
}
89+
90+
@Override
91+
public String toString() {
92+
return "OrtEpDevice{"
93+
+ "epName='"
94+
+ epName
95+
+ '\''
96+
+ ", epVendor='"
97+
+ epVendor
98+
+ '\''
99+
+ ", epMetadata="
100+
+ epMetadata
101+
+ ", epOptions="
102+
+ epOptions
103+
+ ", device="
104+
+ device
105+
+ '}';
106+
}
107+
108+
private static native String getName(long apiHandle, long nativeHandle);
109+
110+
private static native String getVendor(long apiHandle, long nativeHandle);
111+
112+
private static native String[][] getMetadata(long apiHandle, long nativeHandle);
113+
114+
private static native String[][] getOptions(long apiHandle, long nativeHandle);
115+
116+
private static native long getDeviceHandle(long apiHandle, long nativeHandle);
117+
}

java/src/main/java/ai/onnxruntime/providers/OrtFlags.java renamed to java/src/main/java/ai/onnxruntime/OrtFlags.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/*
2-
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
5-
package ai.onnxruntime.providers;
5+
package ai.onnxruntime;
66

77
import java.util.EnumSet;
88

0 commit comments

Comments
 (0)