Skip to content

Commit 3376621

Browse files
authored
keras_format: fix serialized Optimizer class_names (#569)
1 parent 91a9feb commit 3376621

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

src/keras_format/optimizer_config.ts

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,22 @@ export type AdadeltaOptimizerConfig = {
2828
};
2929

3030
export type AdadeltaSerialization =
31-
BaseSerialization<'AdadeltaOptimizer', AdadeltaOptimizerConfig>;
31+
BaseSerialization<'Adadelta', AdadeltaOptimizerConfig>;
3232

3333
export type AdagradOptimizerConfig = {
3434
learning_rate: number;
3535
initial_accumulator_value?: number;
3636
};
3737

3838
export type AdagradSerialization =
39-
BaseSerialization<'AdagradOptimizer', AdagradOptimizerConfig>;
39+
BaseSerialization<'Adagrad', AdagradOptimizerConfig>;
4040

4141
export type AdamOptimizerConfig = {
4242
learning_rate: number; beta1: number; beta2: number;
4343
epsilon?: number;
4444
};
4545

46-
export type AdamSerialization =
47-
BaseSerialization<'AdamOptimizer', AdamOptimizerConfig>;
46+
export type AdamSerialization = BaseSerialization<'Adam', AdamOptimizerConfig>;
4847

4948
export type AdamaxOptimizerConfig = {
5049
learning_rate: number; beta1: number; beta2: number;
@@ -53,7 +52,7 @@ export type AdamaxOptimizerConfig = {
5352
};
5453

5554
export type AdamaxSerialization =
56-
BaseSerialization<'AdamaxOptimizer', AdamaxOptimizerConfig>;
55+
BaseSerialization<'Adamax', AdamaxOptimizerConfig>;
5756

5857
export type MomentumOptimizerConfig = {
5958
// extends SGDOptimizerConfig {
@@ -62,7 +61,7 @@ export type MomentumOptimizerConfig = {
6261
};
6362

6463
export type MomentumSerialization =
65-
BaseSerialization<'MomentumOptimizer', MomentumOptimizerConfig>;
64+
BaseSerialization<'Momentum', MomentumOptimizerConfig>;
6665

6766
export type RMSPropOptimizerConfig = {
6867
learning_rate: number;
@@ -73,14 +72,13 @@ export type RMSPropOptimizerConfig = {
7372
};
7473

7574
export type RMSPropSerialization =
76-
BaseSerialization<'RMSPropOptimizer', RMSPropOptimizerConfig>;
75+
BaseSerialization<'RMSProp', RMSPropOptimizerConfig>;
7776

7877
export type SGDOptimizerConfig = {
7978
learning_rate: number;
8079
};
8180

82-
export type SGDSerialization =
83-
BaseSerialization<'SGDOptimizer', SGDOptimizerConfig>;
81+
export type SGDSerialization = BaseSerialization<'SGD', SGDOptimizerConfig>;
8482

8583
// Update optimizerClassNames below in concert with this.
8684
export type OptimizerSerialization = AdadeltaSerialization|AdagradSerialization|
@@ -97,7 +95,5 @@ export type OptimizerClassName = OptimizerSerialization['class_name'];
9795
*
9896
* This is guaranteed to match the `OptimizerClassName` union type.
9997
*/
100-
export const optimizerClassNames: OptimizerClassName[] = [
101-
'AdadeltaOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer',
102-
'MomentumOptimizer', 'RMSPropOptimizer', 'SGDOptimizer'
103-
];
98+
export const optimizerClassNames: OptimizerClassName[] =
99+
['Adadelta', 'Adagrad', 'Adam', 'Adamax', 'Momentum', 'RMSProp', 'SGD'];

0 commit comments

Comments
 (0)