Skip to content

Commit e1b4a5e

Browse files
authored
Improve builtin index_string (#127)
1 parent ff907f0 commit e1b4a5e

File tree

11 files changed

+26
-21
lines changed

11 files changed

+26
-21
lines changed

docs/applications/resources/transformed-columns.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ See <!-- CORTEX_VERSION_MINOR -->[`transformers.yaml`](https://github.com/cortex
5555
columns:
5656
col: class # the name of a string column
5757
args:
58-
index: ["t", "f"] # a value to be used as the index
58+
index: {"indexes": ["t", "f"], "reversed_index": ["t": 0, "f": 1]} # a value to be used as the index
5959

6060
- kind: transformed_column
6161
name: price_bucketized

docs/tutorial.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ Add to `app.yaml`:
219219
columns:
220220
text: class
221221
args:
222-
index: class_index
222+
indexes: class_index
223223
```
224224

225225
You can simplify the configuration for aggregates and transformed columns using [templates](applications/advanced/templates.md).

examples/iris/implementations/models/dnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ def create_estimator(run_config, model_config):
1212
return tf.estimator.DNNClassifier(
1313
feature_columns=feature_columns,
1414
hidden_units=model_config["hparams"]["hidden_units"],
15-
n_classes=len(model_config["aggregates"]["class_index"]),
15+
n_classes=len(model_config["aggregates"]["class_index"]["index"]),
1616
config=run_config,
1717
)

examples/iris/resources/transformed_columns.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@
4545
columns:
4646
text: class
4747
args:
48-
index: class_index
48+
indexes: class_index

examples/movie-ratings/resources/transformed_columns.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
columns:
1313
text: user_id
1414
args:
15-
index: user_id_index
15+
indexes: user_id_index
1616

1717
- kind: aggregate
1818
name: movie_id_index
@@ -28,4 +28,4 @@
2828
columns:
2929
text: movie_id
3030
args:
31-
index: movie_id_index
31+
indexes: movie_id_index

examples/reviews/resources/columns.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@
2525
columns:
2626
text: label
2727
args:
28-
index: label_index
28+
indexes: label_index

pkg/aggregators/aggregators.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,11 +317,12 @@
317317
# Enumerates the unique values in a string column and orders them by placing the unique strings in
318318
# list ordered by most frequent starting at the 0th index.
319319
# Works well in conjunction with transformers.index_string.
320-
# For example: An input column with the following values ['t', 'f', 't'] would return ['t', 'f'].
320+
# For example: An input column with the following values ['t', 'f', 't'] would return
321+
# {"index": ['t', 'f'], "reversed_index": {'t': 0, 'f': 1}}.
321322
- kind: aggregator
322323
name: index_string
323324
path: index_string.py
324-
output_type: [STRING]
325+
output_type: {"index": [STRING], "reversed_index": {STRING: INT}}
325326
inputs:
326327
columns:
327328
col: STRING_COLUMN

pkg/aggregators/index_string.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ def aggregate_spark(data, columns, args):
1717
from pyspark.ml.feature import StringIndexer
1818

1919
indexer = StringIndexer(inputCol=columns["col"])
20-
return indexer.fit(data).labels
20+
index = indexer.fit(data).labels
21+
reversed_index = {v: k for k, v in enumerate(index)}
22+
return {"index": index, "reversed_index": reversed_index}

pkg/transformers/index_string.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def transform_spark(data, columns, args, transformed_column_name):
1818
import pyspark.sql.functions as F
1919

2020
indexer = StringIndexerModel.from_labels(
21-
args["index"], inputCol=columns["text"], outputCol=transformed_column_name
21+
args["indexes"]["index"], inputCol=columns["text"], outputCol=transformed_column_name
2222
)
2323

2424
return indexer.transform(data).withColumn(
@@ -27,12 +27,11 @@ def transform_spark(data, columns, args, transformed_column_name):
2727

2828

2929
def transform_python(sample, args):
30-
for idx, label in enumerate(args["index"]):
31-
if label == sample["text"]:
32-
return idx
30+
if sample["text"] in args["indexes"]["reversed_index"]:
31+
return args["indexes"]["reversed_index"][sample["text"]]
3332

3433
raise Exception("Could not find {} in index: {}".format(sample["text"], args))
3534

3635

3736
def reverse_transform_python(transformed_value, args):
38-
return args["index"][transformed_value]
37+
return args["indexes"]["index"][transformed_value]

pkg/transformers/transformers.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
# Given labels, map the string column to its index in the labels array.
4444
# Example:
45-
# INPUT: labels = ['r', 'b', 'g'] column = ['r', 'b', 'g', 'g'],
45+
# INPUT: indexes = {"index": ['r', 'b', 'g'], "reversed_index": {'r': 0, 'b': 1, 'g': 2}} column = ['r', 'b', 'g', 'g'],
4646
# OUTPUT: [0, 1, 2, 2]
4747
- kind: transformer
4848
name: index_string
@@ -52,4 +52,4 @@
5252
columns:
5353
text: STRING_COLUMN
5454
args:
55-
index: [STRING]
55+
indexes: {"index": [STRING], "reversed_index": {STRING: INT}}

0 commit comments

Comments
 (0)