Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 45a4b88

Browse files
author
Ryan Sepassi
committed
Fix colab notebook
PiperOrigin-RevId: 179871302
1 parent 87bfac5 commit 45a4b88

File tree

1 file changed

+2
-50
lines changed

1 file changed

+2
-50
lines changed

tensor2tensor/notebooks/hello_t2t.ipynb

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
"import os\n",
8686
"import collections\n",
8787
"\n",
88+
"from tensor2tensor import models\n",
8889
"from tensor2tensor import problems\n",
8990
"from tensor2tensor.layers import common_layers\n",
9091
"from tensor2tensor.tpu import tpu_trainer_lib\n",
@@ -1540,55 +1541,6 @@
15401541
}
15411542
]
15421543
},
1543-
{
1544-
"metadata": {
1545-
"id": "a2cL8UwLaSYG",
1546-
"colab_type": "code",
1547-
"colab": {
1548-
"autoexec": {
1549-
"startup": false,
1550-
"wait_interval": 0
1551-
}
1552-
}
1553-
},
1554-
"source": [
1555-
"# This will eventually be available at\n",
1556-
"# tensor2tensor.metrics.create_eager_metrics\n",
1557-
"def create_eager_metrics(metric_names):\n",
1558-
" \"\"\"Create metrics accumulators and averager for Eager mode.\n",
1559-
"\n",
1560-
" Args:\n",
1561-
" metric_names: list<str> from tensor2tensor.metrics.Metrics\n",
1562-
"\n",
1563-
" Returns:\n",
1564-
" (accum_fn(predictions, targets) => None,\n",
1565-
" result_fn() => dict<str metric_name, float avg_val>\n",
1566-
" \"\"\"\n",
1567-
" metric_fns = dict(\n",
1568-
" [(name, metrics.METRICS_FNS[name]) for name in metric_names])\n",
1569-
" tfe_metrics = dict()\n",
1570-
"\n",
1571-
" for name in metric_names:\n",
1572-
" tfe_metrics[name] = tfe.metrics.Mean(name=name)\n",
1573-
"\n",
1574-
" def metric_accum(predictions, targets):\n",
1575-
" for name, metric_fn in metric_fns.items():\n",
1576-
" val, weight = metric_fn(predictions, targets,\n",
1577-
" weights_fn=common_layers.weights_all)\n",
1578-
" tfe_metrics[name](np.squeeze(val), np.squeeze(weight))\n",
1579-
"\n",
1580-
" def metric_means():\n",
1581-
" avgs = {}\n",
1582-
" for name in metric_names:\n",
1583-
" avgs[name] = tfe_metrics[name].result().numpy()\n",
1584-
" return avgs\n",
1585-
"\n",
1586-
" return metric_accum, metric_means"
1587-
],
1588-
"cell_type": "code",
1589-
"execution_count": 0,
1590-
"outputs": []
1591-
},
15921544
{
15931545
"metadata": {
15941546
"id": "CIFlkiVOd8jO",
@@ -1625,7 +1577,7 @@
16251577
"\n",
16261578
"# Create eval metric accumulators for accuracy (ACC) and accuracy in\n",
16271579
"# top 5 (ACC_TOP5)\n",
1628-
"metrics_accum, metrics_result = create_eager_metrics(\n",
1580+
"metrics_accum, metrics_result = metrics.create_eager_metrics(\n",
16291581
" [metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5])\n",
16301582
"\n",
16311583
"for count, example in enumerate(tfe.Iterator(mnist_eval_dataset)):\n",

0 commit comments

Comments
 (0)