Skip to content

Commit b2e1d49

Browse files
authored
Add out-variants to support ET export
Differential Revision: D62385428 Pull Request resolved: #859
1 parent a584e24 commit b2e1d49

25 files changed

+743
-387
lines changed

torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
// LICENSE file in the root directory of this source tree.
66

77
#pragma once
8+
#include <stdint.h>
89
#include <torchao/experimental/kernels/cpu/macro.h>
910
#include <torchao/experimental/kernels/cpu/parallel.h>
11+
#include <algorithm>
1012
#include <cassert>
1113
#include <cstdlib>
1214

1315
namespace torchao::operators::cpu::linear::
1416
channelwise_8bit_activation_groupwise_lowbit_weight {
1517

16-
PackWeightDataTilingParams get_default_pack_weight_data_tiling_params(
18+
inline PackWeightDataTilingParams get_default_pack_weight_data_tiling_params(
1719
const UKernelConfig& ukernel_config,
1820
int n,
1921
int target_panels_per_thread) {
@@ -38,7 +40,7 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params(
3840
return tiling_params;
3941
}
4042

41-
void pack_weight_data_operator(
43+
inline void pack_weight_data_operator(
4244
const UKernelConfig& ukernel_config,
4345
const PackWeightDataTilingParams& tiling_params,
4446
// Outputs
@@ -79,7 +81,7 @@ void pack_weight_data_operator(
7981
}
8082

8183
// This default mimics XNNPACK behavior if target_tiles_per_thread = 5
82-
LinearTilingParams get_default_linear_tiling_params(
84+
inline LinearTilingParams get_default_linear_tiling_params(
8385
const UKernelConfig& ukernel_config,
8486
int m,
8587
int n,
@@ -137,12 +139,12 @@ get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_n
137139
return ukernel_config.activation_data_size_fn(m, k, group_size);
138140
}
139141

140-
void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
142+
inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
141143
const UKernelConfig& ukernel_config,
142144
const LinearTilingParams& tiling_params,
143145
char* activation_data_buffer,
144146
// Outputs
145-
float32_t* output,
147+
float* output,
146148
// Inputs
147149
int m,
148150
int n,
@@ -199,12 +201,12 @@ void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
199201
}
200202
}
201203

202-
void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
204+
inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
203205
const UKernelConfig& ukernel_config,
204206
const LinearTilingParams& tiling_params,
205207
char* activation_data_buffer,
206208
// Outputs
207-
float32_t* output,
209+
float* output,
208210
// Inputs
209211
int m,
210212
int n,
@@ -271,7 +273,7 @@ void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
271273
}
272274
} // namespace internal
273275

274-
void linear_operator(
276+
inline void linear_operator(
275277
const UKernelConfig& ukernel_config,
276278
const LinearTilingParams& tiling_params,
277279
LinearTileSchedulingPolicy scheduling_policy,
@@ -363,7 +365,7 @@ namespace torchao::operators::cpu::linear::
363365
channelwise_8bit_activation_groupwise_lowbit_weight {
364366
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
365367

366-
UKernelConfig get_ukernel_config() {
368+
inline UKernelConfig get_ukernel_config() {
367369
UKernelConfig config;
368370

369371
namespace ukernel = torchao::kernels::cpu::aarch64::linear::

torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// LICENSE file in the root directory of this source tree.
66

77
#pragma once
8+
#include <stdint.h>
89

910
// TODO: maybe move to operator directory
1011
namespace torchao::operators::cpu::linear::

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,41 @@ include_directories(${TORCHAO_LIBRARIES})
1818

1919
add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
2020

21-
find_package(Torch REQUIRED)
22-
include_directories("${TORCH_INCLUDE_DIRS}")
21+
include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
2322

24-
add_library(torch_custom_op SHARED torch_custom_op.cpp)
25-
target_link_libraries(torch_custom_op PRIVATE "${TORCH_LIBRARIES}")
26-
target_link_libraries(torch_custom_op PRIVATE kernel_aarch64)
23+
set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
24+
string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)
2725

28-
include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
29-
set(TORCHAO_PARALLEL_BACKEND "ATEN_OPENMP" CACHE STRING "Choose parallel backend to use for torchao parallelism (aten_openmp, openmp, pthreadpool, single_threaded)")
30-
target_link_torchao_parallel_backend(torch_custom_op "${TORCHAO_PARALLEL_BACKEND}")
26+
if(PLATFORM_TO_UPPER STREQUAL "ATEN")
27+
message(STATUS "Building with PLATFORM=ATEN")
28+
29+
find_package(Torch REQUIRED)
30+
add_library(lowbit_op_aten SHARED lowbit_op_aten.cpp)
31+
target_link_libraries(lowbit_op_aten PRIVATE kernel_aarch64)
32+
target_include_directories(lowbit_op_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
33+
target_link_libraries(lowbit_op_aten PRIVATE "${TORCH_LIBRARIES}")
34+
target_compile_definitions(lowbit_op_aten PRIVATE USE_ATEN=1)
35+
target_link_torchao_parallel_backend(lowbit_op_aten "ATEN_OPENMP")
36+
37+
elseif(PLATFORM_TO_UPPER STREQUAL "EXECUTORCH")
38+
message(STATUS "Building with PLATFORM=EXECUTORCH")
39+
40+
add_library(lowbit_op_executorch SHARED
41+
lowbit_op_executorch/w2s.cpp
42+
lowbit_op_executorch/w2sz.cpp
43+
lowbit_op_executorch/w3s.cpp
44+
lowbit_op_executorch/w3sz.cpp
45+
lowbit_op_executorch/w4s.cpp
46+
lowbit_op_executorch/w4sz.cpp
47+
lowbit_op_executorch/w5s.cpp
48+
lowbit_op_executorch/w5sz.cpp
49+
)
50+
target_include_directories(lowbit_op_executorch PRIVATE ${EXECUTORCH_INCLUDE_DIRS})
51+
target_compile_definitions(lowbit_op_executorch PRIVATE USE_EXECUTORCH=1)
52+
target_link_torchao_parallel_backend(lowbit_op_executorch "SINGLE_THREADED")
53+
target_link_libraries(lowbit_op_executorch PRIVATE ${EXECUTORCH_LIBRARIES})
54+
target_link_libraries(lowbit_op_executorch PRIVATE kernel_aarch64)
55+
56+
else()
57+
message(FATAL_ERROR "Unknown PLATFORM: ${PLATFORM}. Please choose one of: ATEN, EXECUTORCH.")
58+
endif()

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
1313
export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
1414
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
1515
-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
16-
-DTORCHAO_PARALLEL_BACKEND="aten_openmp" \
16+
-DPLATFORM="ATEN" \
1717
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
1818
-B ${CMAKE_OUT}
1919
cmake --build ${CMAKE_OUT}

0 commit comments

Comments
 (0)