@@ -459,6 +459,54 @@ def testSavedModelRaisesErrorIfArtifactsDirExistsAsAFile(self):
459459 ValueError , r'already exists as a file' ):
460460 conversion .save_keras_model (model , artifacts_dir )
461461
462+ def testTranslateBatchNormalizationV1ClassName (self ):
463+ # The config JSON of a model consisting of a "BatchNormalizationV1" layer.
464+ # pylint: disable=line-too-long
465+ json_object = json .loads (
466+ '{"class_name": "Sequential", "keras_version": "2.2.4-tf", "config": {"layers": [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "GlorotUniform", "config": {"dtype": "float32", "seed": null}}, "name": "dense", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "units": 10, "batch_input_shape": [null, 3], "use_bias": true, "activity_regularizer": null}}, {"class_name": "BatchNormalizationV1", "config": {"beta_constraint": null, "gamma_initializer": {"class_name": "Ones", "config": {"dtype": "float32"}}, "moving_mean_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "name": "batch_normalization_v1", "dtype": "float32", "trainable": true, "moving_variance_initializer": {"class_name": "Ones", "config": {"dtype": "float32"}}, "beta_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "scale": true, "axis": [1], "epsilon": 0.001, "gamma_constraint": null, "gamma_regularizer": null, "beta_regularizer": null, "momentum": 0.99, "center": true}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "GlorotUniform", "config": {"dtype": "float32", "seed": null}}, "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "linear", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "units": 1, "use_bias": true, "activity_regularizer": null}}], "name": "sequential"}, "backend": "tensorflow"}' )
467+ # pylint: enable=line-too-long
468+ conversion .translate_class_names (json_object )
469+ # Some class names should not have been changed be translate_class_names().
470+ self .assertEqual (json_object ['class_name' ], 'Sequential' )
471+ self .assertEqual (json_object ['keras_version' ], '2.2.4-tf' )
472+ self .assertEqual (json_object ['config' ]['layers' ][0 ]['class_name' ], 'Dense' )
473+ # The translation should have happend:
474+ # BatchNormalizationV1 --> BatchNormalization.
475+ self .assertEqual (
476+ json_object ['config' ]['layers' ][1 ]['class_name' ], 'BatchNormalization' )
477+ self .assertEqual (json_object ['config' ]['layers' ][2 ]['class_name' ], 'Dense' )
478+
479+ # Assert that converted JSON can be reconstituted as a model object.
480+ model = keras .models .model_from_json (json .dumps (json_object ))
481+ self .assertTrue (isinstance (model , keras .Sequential ))
482+ self .assertEqual (model .input_shape , (None , 3 ))
483+ self .assertEqual (model .output_shape , (None , 1 ))
484+ self .assertEqual (model .layers [0 ].units , 10 )
485+ self .assertEqual (model .layers [2 ].units , 1 )
486+
487+ def testTranslateUnifiedGRUAndLSTMClassName (self ):
488+ # The config JSON of a model consisting of a "UnifiedGRU" layer
489+ # and a "UnifiedLSTM" layer.
490+ # pylint: disable=line-too-long
491+ json_object = json .loads (
492+ '{"class_name": "Sequential", "keras_version": "2.2.4-tf", "config": {"layers": [{"class_name": "UnifiedGRU", "config": {"recurrent_activation": "sigmoid", "dtype": "float32", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "tanh", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "batch_input_shape": [null, 4, 3], "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "kernel_constraint": null, "time_major": false, "dropout": 0.0, "stateful": false, "reset_after": true, "recurrent_regularizer": null, "name": "unified_gru", "bias_constraint": null, "go_backwards": false, "implementation": 1, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}}, {"class_name": "UnifiedLSTM", "config": {"recurrent_activation": "sigmoid", "dtype": "float32", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "tanh", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 2, "unit_forget_bias": true, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "kernel_constraint": null, "time_major": false, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "unified_lstm", "bias_constraint": null, "go_backwards": false, "implementation": 1, "kernel_regularizer": null, "return_sequences": false, "recurrent_constraint": null}}], "name": "sequential"}, "backend": "tensorflow"}' )
493+ # pylint: enable=line-too-long
494+ conversion .translate_class_names (json_object )
495+ # Some class names should not have been changed be translate_class_names().
496+ self .assertEqual (json_object ['class_name' ], 'Sequential' )
497+ self .assertEqual (json_object ['keras_version' ], '2.2.4-tf' )
498+ # The translation should have happend:
499+ # UnifiedGRU --> GRU.
500+ # UnifiedLSTM --> LSTM.
501+ self .assertEqual (json_object ['config' ]['layers' ][0 ]['class_name' ], 'GRU' )
502+ self .assertEqual (json_object ['config' ]['layers' ][1 ]['class_name' ], 'LSTM' )
503+
504+ # Assert that converted JSON can be reconstituted as a model object.
505+ model = keras .models .model_from_json (json .dumps (json_object ))
506+ self .assertTrue (isinstance (model , keras .Sequential ))
507+ self .assertEqual (model .input_shape , (None , 4 , 3 ))
508+ self .assertEqual (model .output_shape , (None , 2 ))
509+
462510
463511if __name__ == '__main__' :
464512 unittest .main ()
0 commit comments