Skip to content

Commit 4a67d14

Browse files
authored
Transformer model sentiment analysis example (#36)
1 parent 32e3a33 commit 4a67d14

File tree

21 files changed

+152
-264
lines changed

21 files changed

+152
-264
lines changed

docs/applications/implementations/aggregators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def aggregate_spark(data, columns, args):
4242
The following packages have been pre-installed and can be used in your implementations:
4343

4444
```text
45-
pyspark==2.4.0
45+
pyspark==2.4.1
4646
boto3==1.9.78
4747
msgpack==0.6.1
4848
numpy>=1.13.3,<2

docs/applications/implementations/transformers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def reverse_transform_python(transformed_value, args):
8686
The following packages have been pre-installed and can be used in your implementations:
8787

8888
```text
89-
pyspark==2.4.0
89+
pyspark==2.4.1
9090
boto3==1.9.78
9191
msgpack==0.6.1
9292
numpy>=1.13.3,<2

docs/applications/resources/environments.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ data:
3535
3636
#### CSV Config
3737
38-
To help ingest different styles of CSV files, Cortex supports the parameters listed below. All of these parameters are optional. A description and default values for each parameter can be found in the [PySpark CSV Documentation](https://spark.apache.org/docs/2.4.0/api/python/pyspark.sql.html#pyspark.sql.DataFrameReader.csv).
38+
To help ingest different styles of CSV files, Cortex supports the parameters listed below. All of these parameters are optional. A description and default values for each parameter can be found in the [PySpark CSV Documentation](https://spark.apache.org/docs/2.4.1/api/python/pyspark.sql.html#pyspark.sql.DataFrameReader.csv).
3939
4040
```yaml
4141
csv_config:

examples/mnist/implementations/models/t2t.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ def transform_tensorflow(features, labels, model_config):
3535
features["inputs"] = tf.reshape(features["image_pixels"], hparams["input_shape"])
3636

3737
# t2t expects this key and dimensionality
38-
features["targets"] = tf.expand_dims(labels, 0)
38+
features["targets"] = tf.expand_dims(tf.expand_dims(labels, -1), -1)
3939

4040
return features, labels

examples/reviews/implementations/aggregators/max_length.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
def aggregate_spark(data, columns, args):
2-
from pyspark.ml.feature import StopWordsRemover, RegexTokenizer
2+
from pyspark.ml.feature import RegexTokenizer
33
import pyspark.sql.functions as F
44
from pyspark.sql.types import IntegerType
55

66
regexTokenizer = RegexTokenizer(inputCol=columns["col"], outputCol="token_list", pattern="\\W")
77
regexTokenized = regexTokenizer.transform(data)
88

9-
remover = StopWordsRemover(inputCol="token_list", outputCol="filtered_word_list")
109
max_review_length_row = (
11-
remover.transform(regexTokenized)
12-
.select(F.size(F.col("filtered_word_list")).alias("word_count"))
10+
regexTokenized.select(F.size(F.col("token_list")).alias("word_count"))
1311
.agg(F.max(F.col("word_count")).alias("max_review_length"))
1412
.collect()
1513
)
Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
def aggregate_spark(data, columns, args):
22
import pyspark.sql.functions as F
3-
from pyspark.ml.feature import StopWordsRemover, RegexTokenizer
3+
from pyspark.ml.feature import RegexTokenizer
44

5-
input_data = data.withColumn(columns["col"], F.lower(F.col(columns["col"])))
65
regexTokenizer = RegexTokenizer(inputCol=columns["col"], outputCol="token_list", pattern="\\W")
76
regexTokenized = regexTokenizer.transform(data)
87

9-
remover = StopWordsRemover(inputCol="token_list", outputCol="filtered_word_list")
108
vocab_rows = (
11-
remover.transform(regexTokenized)
12-
.select(F.explode(F.col("filtered_word_list")).alias("word"))
9+
regexTokenized.select(F.explode(F.col("token_list")).alias("word"))
1310
.groupBy("word")
1411
.count()
1512
.orderBy(F.col("count").desc())
@@ -19,6 +16,7 @@ def aggregate_spark(data, columns, args):
1916
)
2017

2118
vocab = [row["word"] for row in vocab_rows]
22-
reverse_dict = {word: idx + len(args["reserved_indices"]) for idx, word in enumerate(vocab)}
23-
24-
return {**reverse_dict, **args["reserved_indices"]}
19+
reverse_dict = {word: 2 + idx for idx, word in enumerate(vocab)}
20+
reverse_dict["<PAD>"] = 0
21+
reverse_dict["<UNKNOWN>"] = 1
22+
return reverse_dict
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import tensorflow as tf
2+
from tensor2tensor.utils import trainer_lib
3+
from tensor2tensor import models # pylint: disable=unused-import
4+
from tensor2tensor import problems # pylint: disable=unused-import
5+
from tensor2tensor.data_generators import problem_hparams
6+
from tensor2tensor.utils import registry
7+
from tensor2tensor.utils import metrics
8+
from tensor2tensor.data_generators import imdb
9+
from tensor2tensor.data_generators import text_encoder
10+
11+
12+
def create_estimator(run_config, model_config):
    """Build a tensor2tensor Transformer estimator for IMDB sentiment analysis.

    Args:
        run_config: a tf.estimator RunConfig; mutated in place with the extra
            attributes t2t's trainer_lib expects before it will run.
        model_config: Cortex model config dict; reads the "reviews_vocab"
            aggregate (an ordered vocabulary) from model_config["aggregates"].

    Returns:
        A tf.estimator.Estimator wrapping t2t's "transformer" model.
    """
    # t2t expects these keys in run_config
    run_config.data_parallelism = None
    run_config.t2t_device_info = {"num_async_replicas": 1}

    hparams = trainer_lib.create_hparams("transformer_base_single_gpu")

    # Wire our in-memory-vocab IMDB problem into the hparams so t2t can
    # resolve encoders/shapes without a data_dir vocab file.
    problem = SentimentIMDBCortex(list(model_config["aggregates"]["reviews_vocab"]))
    p_hparams = problem.get_hparams(hparams)
    hparams.problem = problem
    hparams.problem_hparams = p_hparams

    problem.eval_metrics = lambda: [
        metrics.Metrics.ACC_TOP5,
        metrics.Metrics.ACC_PER_SEQ,
        metrics.Metrics.NEG_LOG_PERPLEXITY,
    ]

    # t2t expects this key
    hparams.warm_start_from = None

    # Shrink the base transformer to reduce memory load.
    hparams.num_hidden_layers = 2
    hparams.hidden_size = 32
    hparams.filter_size = 32
    hparams.num_heads = 2

    estimator = trainer_lib.create_estimator("transformer", hparams, run_config)
    return estimator
41+
42+
43+
def transform_tensorflow(features, labels, model_config):
    """Reshape Cortex features/labels into the layout t2t's transformer expects.

    Args:
        features: dict of input tensors; "embedding_input" holds the tokenized
            review, flattened to max_review_length indices.
        labels: the sentiment label tensor (scalar per example — TODO confirm
            against the input pipeline).
        model_config: Cortex model config; reads the "max_review_length"
            aggregate.

    Returns:
        The (features, labels) pair with t2t's required "inputs"/"targets" keys
        added to features.
    """
    max_length = model_config["aggregates"]["max_review_length"]

    # t2t wants "inputs" as [max_length, 1] and "targets" with two trailing
    # singleton dimensions.
    features["inputs"] = tf.expand_dims(tf.reshape(features["embedding_input"], [max_length]), -1)
    features["targets"] = tf.expand_dims(tf.expand_dims(labels, -1), -1)

    return features, labels
50+
51+
52+
class SentimentIMDBCortex(imdb.SentimentIMDB):
    """IMDB sentiment classification problem backed by an in-memory vocab.

    The stock imdb.SentimentIMDB loads its vocabulary from files under
    data_dir; this subclass instead encodes with a vocabulary list supplied
    at construction time (the Cortex "reviews_vocab" aggregate).
    """

    def __init__(self, vocab_list):
        super().__init__()
        # Vocabulary tokens in index order; consumed by feature_encoders.
        self.vocab = vocab_list

    def feature_encoders(self, data_dir):
        """Return input/target encoders, using the in-memory vocab for inputs."""
        # vocab_filename=None forces TokenTextEncoder to use vocab_list
        # instead of reading a file.
        encoder = text_encoder.TokenTextEncoder(vocab_filename=None, vocab_list=self.vocab)

        return {
            "inputs": encoder,
            "targets": text_encoder.ClassLabelEncoder(self.class_labels(data_dir)),
        }

examples/reviews/implementations/transformers/tokenize_string_to_int.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,16 @@
66
def transform_python(sample, args):
    """Tokenize a review into a fixed-length list of vocabulary indices.

    Args:
        sample: dict with "col" holding the raw review text.
        args: dict with "vocab" (token -> index map that includes the
            "<PAD>" and "<UNKNOWN>" entries) and "max_len" (fixed output
            length; longer reviews are truncated, shorter ones padded).

    Returns:
        A list of exactly args["max_len"] integer token indices.
    """
    text = sample["col"].lower()
    token_index_list = []
    vocab = args["vocab"]

    # NOTE(review): `non_word` is a module-level regex defined above this
    # hunk — presumably splitting on non-word characters; confirm against
    # the full file.
    for token in non_word.split(text):
        if len(token) == 0:
            continue
        # Unknown tokens map to the reserved "<UNKNOWN>" index.
        token_index_list.append(vocab.get(token, vocab["<UNKNOWN>"]))
        if len(token_index_list) == args["max_len"]:
            break

    # Right-pad short reviews up to the fixed length.
    for i in range(args["max_len"] - len(token_index_list)):
        token_index_list.append(vocab["<PAD>"])

    return token_index_list

examples/reviews/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensor2tensor==1.10.0

examples/reviews/resources/aggregators.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)