Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit ee70071

Browse files
wenleixfacebook-github-bot
authored and committed
Add clamp_list UDF (#397)
Summary: Pull Request resolved: #397. Applies clamp to each element in the list. Used for sparse feature preprocessing in the recommendation domain. This UDF will be deprecated once TorchArrow supports lambda functions. Reviewed By: dracifer, bearzx. Differential Revision: D37370459. fbshipit-source-id: 8fb01a2ff3bf7cef1f199231d9693f0624907ecb
1 parent c70817b commit ee70071

File tree

4 files changed

+122
-2
lines changed

4 files changed

+122
-2
lines changed

csrc/velox/functions/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@ set(
1313
rec/sigrid_hash.h
1414
rec/firstX.h
1515
rec/compute_score.h
16+
rec/clamp_list.h
1617
register_udf.cpp
17-
)
18+
)
1819

1920
set(
2021
TORCHARROW_UDF_LINK_LIBRARIES
2122
velox_functions_string
2223
velox_functions_prestosql
23-
)
24+
)
2425
set(TORCHARROW_UDF_COMPILE_DEFINITIONS)
2526

2627
if (USE_TORCH)

csrc/velox/functions/functions.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <velox/functions/Registerer.h>
1212
#include "numeric_functions.h"
1313
#include "rec/bucketize.h" // @manual
14+
#include "rec/clamp_list.h" // @manual
1415
#include "rec/compute_score.h" // @manual
1516
#include "rec/firstX.h" // @manual
1617
#include "rec/sigrid_hash.h" // @manual
@@ -368,6 +369,34 @@ inline void registerTorchArrowFunctions() {
368369
velox::Array<int64_t>,
369370
velox::Array<float>>({"get_score_max"});
370371

372+
velox::registerFunction<
373+
ClampListFunction,
374+
velox::Array<int32_t>,
375+
velox::Array<int32_t>,
376+
int32_t,
377+
int32_t>({"clamp_list"});
378+
379+
velox::registerFunction<
380+
ClampListFunction,
381+
velox::Array<int64_t>,
382+
velox::Array<int64_t>,
383+
int64_t,
384+
int64_t>({"clamp_list"});
385+
386+
velox::registerFunction<
387+
ClampListFunction,
388+
velox::Array<float>,
389+
velox::Array<float>,
390+
float,
391+
float>({"clamp_list"});
392+
393+
velox::registerFunction<
394+
ClampListFunction,
395+
velox::Array<double>,
396+
velox::Array<double>,
397+
double,
398+
double>({"clamp_list"});
399+
371400
// TODO: consider to refactor registration code with helper functions
372401
// to save some lines, like https://fburl.com/code/dk6zi7t3
373402

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <algorithm>
#include <cmath>
12+
#include "velox/functions/Udf.h"
13+
#include "velox/type/Type.h"
14+
15+
namespace facebook::torcharrow::functions {
16+
17+
// TODO: remove this function once lambda expression is supported in
18+
// TorchArrow
19+
template <typename T>
20+
struct ClampListFunction {
21+
VELOX_DEFINE_FUNCTION_TYPES(T);
22+
23+
template <typename TOutput, typename TInput, typename TElement>
24+
FOLLY_ALWAYS_INLINE void callNullFree(
25+
TOutput& result,
26+
const TInput& values,
27+
const TElement& lo,
28+
const TElement& hi) {
29+
VELOX_USER_CHECK_LE(lo, hi, "Lo > hi in clamp.");
30+
result.reserve(values.size());
31+
for (const auto& val : values) {
32+
result.push_back(std::clamp(val, lo, hi));
33+
}
34+
}
35+
};
36+
37+
} // namespace facebook::torcharrow::functions
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torcharrow as ta
10+
import torcharrow.dtypes as dt
11+
from torcharrow import functional
12+
13+
14+
class _TestClampBase(unittest.TestCase):
    """Shared test body for ``functional.clamp_list``.

    ``setUpClass`` builds ``base_df_list`` and then calls
    ``setUpTestCaseData``; concrete subclasses override that hook to
    provide the ``df_list`` dataframe the test reads.
    """

    @classmethod
    def setUpClass(cls):
        # Single list-of-int64 column mixing in-range, too-large and
        # negative values.
        int64_rows = [[0, 1, 2, 3], [-100, 100, 10], [0, -1, -2, -3]]
        schema = dt.Struct(fields=[dt.Field("int64", dt.List(dt.int64))])
        cls.base_df_list = ta.dataframe({"int64": int64_rows}, dtype=schema)

        cls.setUpTestCaseData()

    @classmethod
    def setUpTestCaseData(cls):
        # Override in subclass
        # Python doesn't have native "abstract base test" support.
        # So use unittest.SkipTest to skip in base class: https://stackoverflow.com/a/59561905.
        raise unittest.SkipTest("abstract base test")

    def test_clamp_list(self):
        # Clamp every list element into [0, 20] and compare row by row.
        column = type(self).df_list["int64"]
        clamped = list(functional.clamp_list(column, 0, 20))
        expected = [[0, 1, 2, 3], [0, 20, 10], [0, 0, 0, 0]]

        self.assertEqual(clamped, expected)
44+
45+
46+
class TestClampCpu(_TestClampBase):
    """Runs the shared clamp_list test against a copy of the base dataframe."""

    @classmethod
    def setUpTestCaseData(cls):
        # Work on a copy of the dataframe built by the base class.
        base_frame = cls.base_df_list
        cls.df_list = base_frame.copy()
50+
51+
52+
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)