Skip to content

Commit 1e00002

Browse files
fergushendersoncopybara-github
authored andcommitted
Add new sample using TFLite in Play services C++ API into new "cc_api" subdirectory.
This includes (somewhat hacky) CMake code to extract the TFLite in Play services C++ API SDK (`tflite_cc_api`) from the AAR file in the `play_services_tflite_java` Maven package, and build it, and to build and link the sample app against it. PiperOrigin-RevId: 704288200
1 parent e26d3f5 commit 1e00002

File tree

27 files changed

+1372
-0
lines changed

27 files changed

+1372
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
apply plugin: 'com.android.application'
2+
apply from: 'tflite-java-extract-cpp-sdk.gradle'
3+
4+
android {
5+
namespace 'com.google.samples.gms.tflite.cc'
6+
compileSdk 33
7+
8+
defaultConfig {
9+
applicationId 'com.google.samples.gms.tflite.cc'
10+
minSdk 21
11+
targetSdk 33
12+
versionCode 1
13+
versionName '1.0'
14+
15+
testInstrumentationRunner 'androidx.test.runner.AndroidJUnitRunner'
16+
}
17+
18+
buildTypes {
19+
all {
20+
proguardFiles 'proguard-rules.pro'
21+
}
22+
release {
23+
minifyEnabled true
24+
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
25+
}
26+
}
27+
28+
buildFeatures {
29+
prefab true
30+
}
31+
32+
externalNativeBuild {
33+
cmake {
34+
path file('src/main/cpp/CMakeLists.txt')
35+
}
36+
}
37+
38+
aaptOptions {
39+
noCompress 'bin'
40+
}
41+
}
42+
43+
java {
44+
toolchain {
45+
languageVersion.set(JavaLanguageVersion.of(11))
46+
}
47+
}
48+
49+
dependencies {
50+
51+
implementation 'androidx.appcompat:appcompat:1.6.0'
52+
implementation 'com.google.android.material:material:1.8.0'
53+
54+
androidTestImplementation 'androidx.test:rules:1.4.0'
55+
androidTestImplementation 'androidx.test:runner:1.4.0'
56+
androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
57+
androidTestImplementation 'androidx.test.ext:junit:1.1.3'
58+
androidTestImplementation 'com.google.truth:truth:1.1.3'
59+
60+
implementation 'com.google.android.gms:play-services-tflite-java:16.4.0'
61+
implementation 'com.google.android.gms:play-services-tflite-gpu:16.4.0'
62+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
-keepclasseswithmembers class com.google.samples.gms.tflite.cc.MainActivity { *; }
2+
-keepclasseswithmembers class com.google.samples.gms.tflite.cc.TfLiteJni { *; }
3+
-keepclasseswithmembers class android.support.** { *; }
4+
-keepclasseswithmembers class androidx.** { *; }
5+
6+
# The tests use Tasks.await, but the instrumented code doesn't
7+
# Without the appropriate Proguard config, the method would be pruned
8+
-keepclasseswithmembers class com.google.android.gms.tasks.** { *; }
9+
# The tests also uses TfLiteNative.initialize(context)
10+
-keepclasseswithmembers class com.google.android.gms.tflite.java.TfLiteNative { *; }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2023 The TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.samples.gms.tflite.c.instrumentation;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import android.content.Context;
22+
import android.util.Log;
23+
import androidx.test.core.app.ApplicationProvider;
24+
import androidx.test.ext.junit.runners.AndroidJUnit4;
25+
import com.google.android.gms.tasks.Tasks;
26+
import com.google.android.gms.tflite.java.TfLiteNative;
27+
import com.google.samples.gms.tflite.c.TfLiteJni;
28+
import java.util.concurrent.ExecutionException;
29+
import org.junit.Test;
30+
import org.junit.runner.RunWith;
31+
32+
/** Instrumentation tests for the TFLite Native API. */
33+
@RunWith(AndroidJUnit4.class)
34+
public class BasicScenarioTest {
35+
private static final String TAG = "BasicScenarioTest";
36+
37+
@Test
38+
public void basicScenario() throws ExecutionException, InterruptedException {
39+
Context context = ApplicationProvider.getApplicationContext();
40+
Tasks.await(TfLiteNative.initialize(context));
41+
TfLiteJni jni = new TfLiteJni(message -> Log.e(TAG, message));
42+
43+
jni.loadModel(context.getAssets(), "add.tflite");
44+
float[] output = jni.runInference(new float[] {1.f, 3.f});
45+
jni.destroy();
46+
47+
assertThat(output).isEqualTo(new float[] {3.f, 9.f});
48+
}
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2023 The TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.samples.gms.tflite.c.instrumentation;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import android.content.Context;
22+
import android.util.Log;
23+
import androidx.test.core.app.ApplicationProvider;
24+
import androidx.test.ext.junit.runners.AndroidJUnit4;
25+
import com.google.android.gms.tasks.Tasks;
26+
import com.google.android.gms.tflite.client.TfLiteInitializationOptions;
27+
import com.google.android.gms.tflite.gpu.support.TfLiteGpu;
28+
import com.google.android.gms.tflite.java.TfLiteNative;
29+
import com.google.samples.gms.tflite.c.TfLiteJni;
30+
import java.util.concurrent.ExecutionException;
31+
import org.junit.Assume;
32+
import org.junit.Before;
33+
import org.junit.Test;
34+
import org.junit.runner.RunWith;
35+
36+
/** Instrumentation tests for the TFLite Native Acceleration API. */
37+
@RunWith(AndroidJUnit4.class)
38+
public class TfLiteNativeGPUAccelerationTest {
39+
private static final String TAG = "TfLiteNativeGPUAccelerationTest";
40+
41+
private static final TfLiteInitializationOptions options =
42+
TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build();
43+
44+
private Context context;
45+
46+
@Before
47+
public void setUp() throws ExecutionException, InterruptedException {
48+
context = ApplicationProvider.getApplicationContext();
49+
boolean gpuAvailable = Tasks.await(TfLiteGpu.isGpuDelegateAvailable(context));
50+
Assume.assumeTrue("GPU acceleration is unavailable on this device.", gpuAvailable);
51+
Tasks.await(TfLiteNative.initialize(context, options));
52+
}
53+
54+
@Test
55+
public void doInferenceWithAcceleration() {
56+
TfLiteJni jni = new TfLiteJni(message -> Log.e(TAG, message));
57+
jni.initGpuAcceleration();
58+
jni.loadModel(context.getAssets(), "add.tflite");
59+
float[] output = jni.runInference(new float[] {1.f, 3.f});
60+
jni.destroy();
61+
assertThat(output).isEqualTo(new float[] {3.f, 9.f});
62+
}
63+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
3+
xmlns:tools="http://schemas.android.com/tools">
4+
5+
<application
6+
android:allowBackup="true"
7+
android:icon="@mipmap/ic_launcher"
8+
android:label="@string/app_name"
9+
android:taskAffinity=""
10+
android:theme="@style/AppTheme.TFLite">
11+
<activity
12+
android:name="com.google.samples.gms.tflite.cc.MainActivity"
13+
android:exported="true">
14+
<intent-filter>
15+
<action android:name="android.intent.action.MAIN" />
16+
<category android:name="android.intent.category.LAUNCHER" />
17+
</intent-filter>
18+
</activity>
19+
</application>
20+
21+
<queries>
22+
<package android:name="com.google.android.gms.policy_tflite_dynamite_dynamite" />
23+
</queries>
24+
25+
</manifest>
Binary file not shown.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
cmake_minimum_required(VERSION 3.16.3)
2+
3+
project("tflite-cc-sample")
4+
5+
#-----------------------------------------------------------------------------
6+
# Set up FlatBuffers dependency.
7+
# (tflite_cc_api requires absl, and absl requires flatbuffers.)
8+
9+
include(FetchContent)
10+
FetchContent_Declare(
11+
flatbuffers
12+
GIT_REPOSITORY https://github.com/google/flatbuffers.git
13+
# Keep in sync with the FlatBuffers version in the most recent version
14+
# of TF Lite (Lite RT) for Play services published on maven.google.com
15+
# <https://maven.google.com/web/index.html#com.google.android.gms:play-services-tflite-java>
16+
# and in particular as determined by the `static_assert` in the file
17+
# `tensorflow/lite/acceleration/configuration/configuration_generated.h`
18+
# in the `prefab/modules/tensorflowlite_jni_gms_client/include/` directory
19+
# of the `play-services-tflite-java-<version>.aar` file there.
20+
#
21+
# That version is currently "GIT_TAG v24.3.25",
22+
# but due to a build error in v24.3.25, we use a snapshot that is three PRs
23+
# later (all bug fixes), which fixes the build error.
24+
# See <https://github.com/google/flatbuffers/commit/e6463926479bd6b330cbcf673f7e917803fd5831>.
25+
GIT_TAG e6463926479bd6b330cbcf673f7e917803fd5831
26+
)
27+
set(FLATBUFFERS_BUILD_FLATC OFF)
28+
set(FLATBUFFERS_BUILD_TESTS OFF)
29+
set(FLATBUFFERS_INSTALL OFF)
30+
FetchContent_MakeAvailable(flatbuffers)
31+
32+
#-----------------------------------------------------------------------------
33+
# Set up ABSL dependency.
34+
35+
FetchContent_Declare(
36+
absl
37+
GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
38+
GIT_TAG 20240722.0 # Released on Aug 1, 2024
39+
)
40+
FetchContent_MakeAvailable(absl)
41+
42+
#-----------------------------------------------------------------------------
43+
# Set up TFLite in Play services C API (tensorflowlite_jni_gms_client) dependency.
44+
45+
find_package(tensorflowlite_jni_gms_client REQUIRED CONFIG)
46+
47+
#-----------------------------------------------------------------------------
48+
# Set up TFLite in Play services C++ API (tflite_cc_api) dependency.
49+
50+
set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/Modules" ${CMAKE_MODULE_PATH})
51+
set(tflite_cc_api_DIR "${CMAKE_SOURCE_DIR}/tflite_cc_sdk")
52+
53+
find_package(tflite_cc_api REQUIRED MODULE)
54+
include_directories(${tflite_cc_api_INCLUDE_DIR})
55+
add_subdirectory(${tflite_cc_api_INCLUDE_DIR} tflite_cc_api_build)
56+
57+
#-----------------------------------------------------------------------------
58+
# Set up compile definitions to enable use of TFLite in Play services
59+
# (rather than regular TFLite bundled with the app).
60+
61+
add_compile_definitions(TFLITE_IN_GMSCORE)
62+
add_compile_definitions(TFLITE_WITH_STABLE_ABI)
63+
add_compile_definitions(TFLITE_USE_OPAQUE_DELEGATE)
64+
65+
#-----------------------------------------------------------------------------
66+
# Define how to build this sample app's native code,
67+
# which is embedded in a JNI library so that it can be
68+
# called from the sample app's Java code.
69+
70+
add_library(tflite-cc-sample-jni SHARED
71+
com_google_samples_gms_tflite_cc_TfLiteJni.cc
72+
logging_assert.h
73+
java_interop.h)
74+
75+
target_link_libraries(tflite-cc-sample-jni
76+
tensorflowlite_jni_gms_client::tensorflowlite_jni_gms_client
77+
tflite_cc_api::tflite_cc_api
78+
flatbuffers
79+
android
80+
log)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# This CMake code fragment determines the location of the tflite_cc_api package.
2+
set(tflite_cc_api_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/tflite_cc_sdk")
3+
include(FindPackageHandleStandardArgs)
4+
find_package_handle_standard_args(tflite_cc_api DEFAULT_MSG
5+
tflite_cc_api_INCLUDE_DIR)

0 commit comments

Comments
 (0)