@@ -775,6 +775,30 @@ def func(x):
775775 return tf .identity (y [0 ], name = "output" ), tf .identity (y [1 ], name = "output1" )
776776 self .run_test_case (func , {"input:0" : x_val }, [], ["output:0" , "output1:0" ], rtol = 1e-05 , atol = 1e-06 )
777777
778+ @check_tf_min_version ("2.0" )
779+ def test_keras_bilstm_recurrent_activation_is_hard_sigmoid (self ):
780+ in_shape = [10 , 3 ]
781+ x_val = np .random .uniform (size = [2 , 10 , 3 ]).astype (np .float32 )
782+
783+ model_in = tf .keras .layers .Input (tuple (in_shape ), batch_size = 2 )
784+ x = tf .keras .layers .Bidirectional (
785+ tf .keras .layers .LSTM (
786+ units = 5 ,
787+ return_sequences = True ,
788+ return_state = True ,
789+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
790+ recurrent_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
791+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ),
792+ recurrent_activation = "hard_sigmoid" ,
793+ )
794+ )(model_in )
795+ model = tf .keras .models .Model (inputs = model_in , outputs = x )
796+
797+ def func (x ):
798+ y = model (x )
799+ return tf .identity (y [0 ], name = "output" ), tf .identity (y [1 ], name = "output1" )
800+ self .run_test_case (func , {"input:0" : x_val }, [], ["output:0" , "output1:0" ], rtol = 1e-05 , atol = 1e-06 )
801+
778802 @check_tf_min_version ("2.0" )
779803 @skip_tfjs ("TFJS converts model incorrectly" )
780804 def test_keras_lstm_sigmoid_dropout (self ):
0 commit comments