|
85 | 85 | "import os\n", |
86 | 86 | "import collections\n", |
87 | 87 | "\n", |
| 88 | + "from tensor2tensor import models\n", |
88 | 89 | "from tensor2tensor import problems\n", |
89 | 90 | "from tensor2tensor.layers import common_layers\n", |
90 | 91 | "from tensor2tensor.tpu import tpu_trainer_lib\n", |
|
1540 | 1541 | } |
1541 | 1542 | ] |
1542 | 1543 | }, |
1543 | | - { |
1544 | | - "metadata": { |
1545 | | - "id": "a2cL8UwLaSYG", |
1546 | | - "colab_type": "code", |
1547 | | - "colab": { |
1548 | | - "autoexec": { |
1549 | | - "startup": false, |
1550 | | - "wait_interval": 0 |
1551 | | - } |
1552 | | - } |
1553 | | - }, |
1554 | | - "source": [ |
1555 | | - "# This will eventually be available at\n", |
1556 | | - "# tensor2tensor.metrics.create_eager_metrics\n", |
1557 | | - "def create_eager_metrics(metric_names):\n", |
1558 | | - " \"\"\"Create metrics accumulators and averager for Eager mode.\n", |
1559 | | - "\n", |
1560 | | - " Args:\n", |
1561 | | - " metric_names: list<str> from tensor2tensor.metrics.Metrics\n", |
1562 | | - "\n", |
1563 | | - " Returns:\n", |
1564 | | - " (accum_fn(predictions, targets) => None,\n", |
1565 | | - " result_fn() => dict<str metric_name, float avg_val>\n", |
1566 | | - " \"\"\"\n", |
1567 | | - " metric_fns = dict(\n", |
1568 | | - " [(name, metrics.METRICS_FNS[name]) for name in metric_names])\n", |
1569 | | - " tfe_metrics = dict()\n", |
1570 | | - "\n", |
1571 | | - " for name in metric_names:\n", |
1572 | | - " tfe_metrics[name] = tfe.metrics.Mean(name=name)\n", |
1573 | | - "\n", |
1574 | | - " def metric_accum(predictions, targets):\n", |
1575 | | - " for name, metric_fn in metric_fns.items():\n", |
1576 | | - " val, weight = metric_fn(predictions, targets,\n", |
1577 | | - " weights_fn=common_layers.weights_all)\n", |
1578 | | - " tfe_metrics[name](np.squeeze(val), np.squeeze(weight))\n", |
1579 | | - "\n", |
1580 | | - " def metric_means():\n", |
1581 | | - " avgs = {}\n", |
1582 | | - " for name in metric_names:\n", |
1583 | | - " avgs[name] = tfe_metrics[name].result().numpy()\n", |
1584 | | - " return avgs\n", |
1585 | | - "\n", |
1586 | | - " return metric_accum, metric_means" |
1587 | | - ], |
1588 | | - "cell_type": "code", |
1589 | | - "execution_count": 0, |
1590 | | - "outputs": [] |
1591 | | - }, |
1592 | 1544 | { |
1593 | 1545 | "metadata": { |
1594 | 1546 | "id": "CIFlkiVOd8jO", |
|
1625 | 1577 | "\n", |
1626 | 1578 | "# Create eval metric accumulators for accuracy (ACC) and accuracy in\n", |
1627 | 1579 | "# top 5 (ACC_TOP5)\n", |
1628 | | - "metrics_accum, metrics_result = create_eager_metrics(\n", |
| 1580 | + "metrics_accum, metrics_result = metrics.create_eager_metrics(\n", |
1629 | 1581 | " [metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5])\n", |
1630 | 1582 | "\n", |
1631 | 1583 | "for count, example in enumerate(tfe.Iterator(mnist_eval_dataset)):\n", |
|
0 commit comments