|
1 | 1 | /* |
2 | | - * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. |
| 2 | + * Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved. |
3 | 3 | * Licensed under the MIT License. |
4 | 4 | */ |
5 | 5 | package ai.onnxruntime; |
|
8 | 8 | import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; |
9 | 9 | import java.io.IOException; |
10 | 10 | import java.nio.ByteBuffer; |
| 11 | +import java.util.ArrayList; |
| 12 | +import java.util.Collections; |
11 | 13 | import java.util.EnumSet; |
| 14 | +import java.util.List; |
| 15 | +import java.util.Map; |
12 | 16 | import java.util.Objects; |
13 | 17 | import java.util.logging.Logger; |
14 | 18 |
|
@@ -442,6 +446,48 @@ public static EnumSet<OrtProvider> getAvailableProviders() { |
442 | 446 | return OnnxRuntime.providers.clone(); |
443 | 447 | } |
444 | 448 |
|
| 449 | + /** |
| 450 | + * Registers an execution provider library with this OrtEnvironment. |
| 451 | + * |
| 452 | + * @param registrationName The name to register the library with (used to remove it later with |
| 453 | + * {@link #unregisterExecutionProviderLibrary(String)}). |
| 454 | + * @param libraryPath The path to the library binary on disk. |
| 455 | + * @throws OrtException If the library could not be registered. |
| 456 | + */ |
| 457 | + public void registerExecutionProviderLibrary(String registrationName, String libraryPath) |
| 458 | + throws OrtException { |
| 459 | + registerExecutionProviderLibrary( |
| 460 | + OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath); |
| 461 | + } |
| 462 | + |
| 463 | + /** |
| 464 | + * Unregisters an execution provider library from this OrtEnvironment. |
| 465 | + * |
| 466 | + * @param registrationName The name the library was registered under. |
| 467 | + * @throws OrtException If the library could not be removed. |
| 468 | + */ |
| 469 | + public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException { |
| 470 | + unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName); |
| 471 | + } |
| 472 | + |
| 473 | + /** |
| 474 | + * Get the list of all execution provider and device combinations that are available. |
| 475 | + * |
| 476 | + * @see OrtSession.SessionOptions#addExecutionProvider(List, Map) |
| 477 | + * @return The list of execution provider and device combinations. |
| 478 | + * @throws OrtException If the devices could not be listed. |
| 479 | + */ |
| 480 | + public List<OrtEpDevice> getEpDevices() throws OrtException { |
| 481 | + long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle); |
| 482 | + |
| 483 | + List<OrtEpDevice> devicesList = new ArrayList<>(); |
| 484 | + for (long deviceHandle : deviceHandles) { |
| 485 | + devicesList.add(new OrtEpDevice(deviceHandle)); |
| 486 | + } |
| 487 | + |
| 488 | + return Collections.unmodifiableList(devicesList); |
| 489 | + } |
| 490 | + |
445 | 491 | /** |
446 | 492 | * Creates the native object. |
447 | 493 | * |
@@ -476,6 +522,40 @@ private static native long createHandle( |
476 | 522 | */ |
477 | 523 | private static native long getDefaultAllocator(long apiHandle) throws OrtException; |
478 | 524 |
|
| 525 | + /** |
| 526 | + * Registers the specified execution provider with this OrtEnvironment. |
| 527 | + * |
| 528 | + * @param apiHandle The API handle. |
| 529 | + * @param nativeHandle The OrtEnvironment handle. |
| 530 | + * @param registrationName The name of the execution provider. |
| 531 | + * @param libraryPath The path to the execution provider binary. |
| 532 | + * @throws OrtException If the registration failed. |
| 533 | + */ |
| 534 | + private static native void registerExecutionProviderLibrary( |
| 535 | + long apiHandle, long nativeHandle, String registrationName, String libraryPath) |
| 536 | + throws OrtException; |
| 537 | + |
| 538 | + /** |
| 539 | + * Removes the specified execution provider from this OrtEnvironment. |
| 540 | + * |
| 541 | + * @param apiHandle The API handle. |
| 542 | + * @param nativeHandle The OrtEnvironment handle. |
| 543 | + * @param registrationName The name of the execution provider. |
| 544 | + * @throws OrtException If the removal failed. |
| 545 | + */ |
| 546 | + private static native void unregisterExecutionProviderLibrary( |
| 547 | + long apiHandle, long nativeHandle, String registrationName) throws OrtException; |
| 548 | + |
| 549 | + /** |
| 550 | + * Gets handles for the EP device tuples available in this OrtEnvironment. |
| 551 | + * |
| 552 | + * @param apiHandle The API handle to use. |
| 553 | + * @param nativeHandle The OrtEnvironment handle. |
| 554 | + * @return An array of OrtEpDevice handles. |
| 555 | + * @throws OrtException If the call failed. |
| 556 | + */ |
| 557 | + private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException; |
| 558 | + |
479 | 559 | /** |
480 | 560 | * Closes the OrtEnvironment, frees the handle. |
481 | 561 | * |
|
0 commit comments