|
@@ -10,34 +10,32 @@
 
 
 def create_estimator(run_config, model_config):
-    # t2t expects these keys in run_config
-    run_config.data_parallelism = None
-    run_config.t2t_device_info = {"num_async_replicas": 1}
-
     hparams = trainer_lib.create_hparams("transformer_base_single_gpu")
 
+    # SentimentIMDBCortex subclasses SentimentIMDB
     problem = SentimentIMDBCortex(list(model_config["aggregates"]["reviews_vocab"]))
-    p_hparams = problem.get_hparams(hparams)
     hparams.problem = problem
-    hparams.problem_hparams = p_hparams
+    hparams.problem_hparams = problem.get_hparams(hparams)
 
+    # metrics specific to the sentiment problem
     problem.eval_metrics = lambda: [
         metrics.Metrics.ACC_TOP5,
         metrics.Metrics.ACC_PER_SEQ,
         metrics.Metrics.NEG_LOG_PERPLEXITY,
     ]
 
-    # t2t expects this key
-    hparams.warm_start_from = None
-
     # reduce memory load
     hparams.num_hidden_layers = 2
     hparams.hidden_size = 32
     hparams.filter_size = 32
     hparams.num_heads = 2
 
-    estimator = trainer_lib.create_estimator("transformer", hparams, run_config)
-    return estimator
+    # t2t expects these keys
+    hparams.warm_start_from = None
+    run_config.data_parallelism = None
+    run_config.t2t_device_info = {"num_async_replicas": 1}
+
+    return trainer_lib.create_estimator("transformer", hparams, run_config)
 
 
 def transform_tensorflow(features, labels, model_config):
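
For context (not part of the commit), here is a minimal sketch of how the rewritten create_estimator might be exercised. It assumes the caller supplies a plain tf.estimator.RunConfig and a model_config dict shaped like the one above; the toy vocabulary and the model directory are hypothetical stand-ins for the values the surrounding framework would normally inject.

import tensorflow as tf

# hypothetical stand-ins for framework-provided values
toy_vocab = ["<pad>", "<EOS>", "the", "movie", "was", "great", "terrible"]
model_config = {"aggregates": {"reviews_vocab": toy_vocab}}
run_config = tf.estimator.RunConfig(model_dir="/tmp/sentiment_transformer")

estimator = create_estimator(run_config, model_config)

# Training or evaluation then needs an input_fn; tensor2tensor problems expose
# make_estimator_input_fn(mode, hparams, data_dir=...) for this, after which
# estimator.train(input_fn=..., max_steps=...) can be called.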
|