Skip to content

Commit b4d9c89

Browse files
Fix Progbar.update when receiving list, np arrays, and tensors. (#21823)
* Fix `Progbar.update` when encountering list, np array and tensor. * Resolve comments.
1 parent d8e0b4a commit b4d9c89

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

keras/src/utils/progbar.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sys
44
import time
55

6+
import numpy as np
7+
68
from keras.src.api_export import keras_export
79
from keras.src.utils import io_utils
810

@@ -161,7 +163,10 @@ def update(self, current, values=None, finalize=None):
161163
for k in self._values_order:
162164
info += f" - {k}:"
163165
if isinstance(self._values[k], list):
164-
avg = self._values[k][0] / max(1, self._values[k][1])
166+
values, count = self._values[k]
167+
if not isinstance(values, float):
168+
values = np.mean(values)
169+
avg = values / max(1, count)
165170
if abs(avg) > 1e-3:
166171
info += f" {avg:.4f}"
167172
else:
@@ -188,7 +193,10 @@ def update(self, current, values=None, finalize=None):
188193
info += f" -{self._format_time(time_per_unit, self.unit_name)}"
189194
for k in self._values_order:
190195
info += f" - {k}:"
191-
avg = self._values[k][0] / max(1, self._values[k][1])
196+
values, count = self._values[k]
197+
if not isinstance(values, float):
198+
values = np.mean(values)
199+
avg = values / max(1, count)
192200
if avg > 1e-3:
193201
info += f" {avg:.4f}"
194202
else:

keras/src/utils/progbar_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
from absl.testing import parameterized
3+
4+
from keras.src import testing
5+
from keras.src.utils import progbar
6+
7+
8+
class ProgbarTest(testing.TestCase):
9+
@parameterized.named_parameters(
10+
[
11+
("float", "float"),
12+
("np", "np"),
13+
("list", "list"),
14+
]
15+
)
16+
def test_update(self, value_type):
17+
if value_type == "float":
18+
values = 1.0
19+
elif value_type == "np":
20+
values = np.array(1.0)
21+
elif value_type == "list":
22+
values = [0.0, 1.0, 2.0]
23+
else:
24+
raise ValueError("Unknown value_type")
25+
pb = progbar.Progbar(target=1, verbose=1)
26+
27+
pb.update(1, values=[("values", values)], finalize=True)

0 commit comments

Comments
 (0)