This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit 58a9e12

vancexu authored and wenleix committed

Add scale_to_z_score (#432)

Summary:
Pull Request resolved: #432

scale_to_z_score is a common transform during preprocessing. The implementation is similar to the tft reference here: https://www.tensorflow.org/tfx/transform/api_docs/python/tft/scale_to_z_score

Reviewed By: Tianshu-Bao
Differential Revision: D37771097
fbshipit-source-id: b0fbe28af7b768ac857ac24f15727b981ee17262

1 parent 9250038 commit 58a9e12
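
For reference, the tft analogue linked above is normally applied inside a preprocessing_fn. The sketch below only illustrates that reference API and is not part of this commit; the feature name "score" is made up for the example.

import tensorflow_transform as tft

def preprocessing_fn(inputs):
    # tft.scale_to_z_score standardizes a feature to mean 0 / variance 1,
    # the same behavior this commit adds to torcharrow.functional.
    return {"score_z": tft.scale_to_z_score(inputs["score"])}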

File tree

2 files changed: +73 -0 lines changed


torcharrow/functional.py

Lines changed: 19 additions & 0 deletions
@@ -9,6 +9,8 @@
 from types import ModuleType
 from typing import Dict, List, Optional, Set, Union
 
+import torcharrow.dtypes as dt
+
 from torcharrow.icolumn import Column
 from torcharrow.ilist_column import ListColumn
 from torcharrow.inumerical_column import NumericalColumn
@@ -508,3 +510,20 @@ def scale_to_0_1(col: NumericalColumn) -> NumericalColumn:
     else:
         # TODO: we should add explicit stub to sigmoid
         return sys.modules["torcharrow.functional"].sigmoid(col)
+
+
+def scale_to_z_score(col: NumericalColumn) -> NumericalColumn:
+    """
+    Return the column data scaled to mean 0 and variance 1 (standard deviation 1).
+    Scaling to z-score subtracts out the mean and divides by standard deviation.
+    Note that the standard deviation computed here is based on the biased variance (0 delta degrees of freedom).
+    If input col contains a single distinct value, then the input is returned without scaling.
+    If input col is integral, the output is cast to float32.
+    """
+    assert isinstance(col, NumericalColumn)
+    std = col.std()
+    if std == 0:
+        if dt.is_integer(col.dtype):
+            return col.cast(dt.Float32(col.dtype.nullable))
+        return col
+    return (col - col.mean()) / std
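
A minimal usage sketch of the new function (not part of the diff), assuming torcharrow.functional is reached the same way as in the tests below:

import torcharrow as ta
from torcharrow import functional

c = ta.column([1, 2, 3, 4, 5])        # int64 column
z = functional.scale_to_z_score(c)    # integral input, so the result is cast to float32
print(list(z), z.dtype)

For a constant column the data comes back unscaled (cast to float32 when the input is integral), matching the std == 0 branch above.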

torcharrow/test/test_functional_cpu.py

Lines changed: 54 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 import unittest
 
+import numpy as np
+
 import torcharrow as ta
 import torcharrow._torcharrow
 import torcharrow.dtypes as dt
@@ -96,6 +98,58 @@ def test_scale_to_0_1(self):
         with self.assertRaises(AssertionError):
             functional.scale_to_0_1(c)
 
+    def test_scale_to_z_score(self):
+        # norm same int
+        c = ta.column([1, 1], device=self.device)
+        self.assertEqual(c.dtype, dt.int64)
+        result = functional.scale_to_z_score(c)
+        self.assertEqual(
+            list(result),
+            [1, 1],
+        )
+        self.assertEqual(result.dtype, dt.float32)
+
+        # norm same double
+        c = ta.column([np.float64(1), np.float64(1)], device=self.device)
+        self.assertEqual(c.dtype, dt.float64)
+        result = functional.scale_to_z_score(c)
+        self.assertEqual(
+            list(result),
+            [1, 1],
+        )
+        self.assertEqual(result.dtype, dt.float64)
+
+        # norm float
+        c = ta.column([1.0, 1.0, 2.0, 2.0], device=self.device)
+        self.assertEqual(
+            list(functional.scale_to_z_score(c)),
+            [
+                -0.866025447845459,
+                -0.866025447845459,
+                0.866025447845459,
+                0.866025447845459,
+            ],
+        )
+
+        # norm int with None
+        c = ta.column([1, 2, 3, None, 4, 5], device=self.device)
+        self.assertEqual(
+            list(functional.scale_to_z_score(c)),
+            [
+                -1.2649110555648804,
+                -0.6324555277824402,
+                0.0,
+                None,
+                0.6324555277824402,
+                1.2649110555648804,
+            ],
+        )
+
+        # test assert
+        c = ta.column(["foo", "bar"])
+        with self.assertRaises(AssertionError):
+            functional.scale_to_z_score(c)
+
 
 if __name__ == "__main__":
     unittest.main()
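
The expected constants in the float and int-with-None cases can be reproduced with plain numpy. This is only a sanity-check sketch, and it assumes col.std() behaves like numpy's sample standard deviation (ddof=1), which is what the asserted values imply.

import numpy as np

# float case from the test: [1.0, 1.0, 2.0, 2.0]
x = np.array([1.0, 1.0, 2.0, 2.0])
z = (x - x.mean()) / x.std(ddof=1)
# approx. +/-0.8660254; rounded to float32 and widened back to double,
# this is the 0.866025447845459 asserted above
print(z.astype(np.float32))

# non-null values of the int-with-None case: [1, 2, 3, 4, 5]
y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
zy = (y - y.mean()) / y.std(ddof=1)
# approx. [-1.2649111, -0.6324555, 0.0, 0.6324555, 1.2649111]
print(zy.astype(np.float32))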
