1818from unittest import mock
1919
2020from google .protobuf import text_format
21+ import numpy as np
2122import tensorflow as tf
2223
2324from tensorboard .plugins .hparams import _keras
2425from tensorboard .plugins .hparams import metadata
2526from tensorboard .plugins .hparams import plugin_data_pb2
2627from tensorboard .plugins .hparams import summary_v2 as hp
2728
28- # Stay on Keras 2 for now: https://github.com/keras-team/keras/issues/18467.
29- version_fn = getattr (tf .keras , "version" , None )
30- if version_fn and version_fn ().startswith ("3." ):
31- import tf_keras as keras # Keras 2
32- else :
33- keras = tf .keras # Keras 2
34-
35- tf .compat .v1 .enable_eager_execution ()
36-
3729
3830class CallbackTest (tf .test .TestCase ):
3931 def setUp (self ):
@@ -46,12 +38,12 @@ def _initialize_model(self, writer):
4638 "optimizer" : "adam" ,
4739 HP_DENSE_NEURONS : 8 ,
4840 }
49- self .model = keras .models .Sequential (
41+ self .model = tf . keras .models .Sequential (
5042 [
51- keras .layers .Dense (
43+ tf . keras .layers .Dense (
5244 self .hparams [HP_DENSE_NEURONS ], input_shape = (1 ,)
5345 ),
54- keras .layers .Dense (1 , activation = "sigmoid" ),
46+ tf . keras .layers .Dense (1 , activation = "sigmoid" ),
5547 ]
5648 )
5749 self .model .compile (loss = "mse" , optimizer = self .hparams ["optimizer" ])
@@ -69,7 +61,11 @@ def mock_time():
6961 initial_time = mock_time .time
7062 with mock .patch ("time.time" , mock_time ):
7163 self ._initialize_model (writer = self .logdir )
72- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
64+ self .model .fit (
65+ x = tf .constant ([(1 ,)]),
66+ y = tf .constant ([(2 ,)]),
67+ callbacks = [self .callback ],
68+ )
7369 final_time = mock_time .time
7470
7571 files = os .listdir (self .logdir )
@@ -142,7 +138,11 @@ def test_explicit_writer(self):
142138 filename_suffix = ".magic" ,
143139 )
144140 self ._initialize_model (writer = writer )
145- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
141+ self .model .fit (
142+ x = tf .constant ([(1 ,)]),
143+ y = tf .constant ([(2 ,)]),
144+ callbacks = [self .callback ],
145+ )
146146
147147 files = os .listdir (self .logdir )
148148 self .assertEqual (len (files ), 1 , files )
@@ -158,15 +158,27 @@ def test_non_eager_failure(self):
158158 with self .assertRaisesRegex (
159159 RuntimeError , "only supported in TensorFlow eager mode"
160160 ):
161- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
161+ self .model .fit (
162+ x = np .ones ((10 , 10 )),
163+ y = np .ones ((10 , 10 )),
164+ callbacks = [self .callback ],
165+ )
162166
163167 def test_reuse_failure (self ):
164168 self ._initialize_model (writer = self .logdir )
165- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
169+ self .model .fit (
170+ x = tf .constant ([(1 ,)]),
171+ y = tf .constant ([(2 ,)]),
172+ callbacks = [self .callback ],
173+ )
166174 with self .assertRaisesRegex (
167175 RuntimeError , "cannot be reused across training sessions"
168176 ):
169- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
177+ self .model .fit (
178+ x = tf .constant ([(1 ,)]),
179+ y = tf .constant ([(2 ,)]),
180+ callbacks = [self .callback ],
181+ )
170182
171183 def test_invalid_writer (self ):
172184 with self .assertRaisesRegex (
0 commit comments