Skip to content

Commit 3a3728f

Browse files
authored
Add support for user-defined metadata; Simplify fromMemory() usage (#571)
FEATURE - Add two public methods to `tf.LayersModel` and `tf.Sequential` - `setUserDefinedMetadata()` - `getUserDefinedMetadata()` These methods allow users to set and get custom metadata about the model. - Any set user metadata is serialized together with the model. - User-defined metadata is also deserialized (i.e., loaded) together with the topology and/or weights of the model. - The user-defined metadata is required to be a plain JSON object. This PR implements the enforcement mechanisms that will throw errors if this condition is not met. - During serialization and deserialization, the size of the JSON object will be checked. If it is greater than 1 MB in length, a warning will be thrown. Towards #1596
1 parent e47a877 commit 3a3728f

File tree

6 files changed

+359
-69
lines changed

6 files changed

+359
-69
lines changed

src/engine/training.ts

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import * as losses from '../losses';
2727
import * as Metrics from '../metrics';
2828
import * as optimizers from '../optimizers';
2929
import {LossOrMetricFn} from '../types';
30+
import {checkUserDefinedMetadata} from '../user_defined_metadata';
3031
import {count, pyListRepeat, singletonOrArray, toCamelCase, toSnakeCase, unique} from '../utils/generic_utils';
3132
import {printSummary} from '../utils/layer_utils';
3233
import {range} from '../utils/math_utils';
@@ -506,6 +507,9 @@ export class LayersModel extends Container implements tfc.InferenceModel {
506507
// implicit "knowledge" of the outputs it depends on.
507508
metricsTensors: Array<[LossOrMetricFn, number]>;
508509

510+
// User defined metadata (if any).
511+
private userDefinedMetadata: {};
512+
509513
constructor(args: ContainerArgs) {
510514
super(args);
511515
this.isTraining = false;
@@ -837,8 +841,8 @@ export class LayersModel extends Container implements tfc.InferenceModel {
837841
// TODO(cais): Standardize `config.sampleWeights` as well.
838842
// Validate user data.
839843
const checkBatchAxis = true;
840-
const standardizedOuts = this.standardizeUserDataXY(
841-
x, y, checkBatchAxis, batchSize);
844+
const standardizedOuts =
845+
this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
842846
try {
843847
// TODO(cais): If uses `useLearningPhase`, set the corresponding element
844848
// of the input to 0.
@@ -1136,10 +1140,9 @@ export class LayersModel extends Container implements tfc.InferenceModel {
11361140
}
11371141

11381142
protected standardizeUserDataXY(
1139-
x: Tensor|Tensor[]|{[inputName: string]: Tensor},
1140-
y: Tensor|Tensor[]|{[inputName: string]: Tensor},
1141-
checkBatchAxis = true,
1142-
batchSize?: number): [Tensor[], Tensor[]] {
1143+
x: Tensor|Tensor[]|{[inputName: string]: Tensor},
1144+
y: Tensor|Tensor[]|{[inputName: string]: Tensor}, checkBatchAxis = true,
1145+
batchSize?: number): [Tensor[], Tensor[]] {
11431146
// TODO(cais): Add sampleWeight, classWeight
11441147
if (this.optimizer_ == null) {
11451148
throw new RuntimeError(
@@ -1348,8 +1351,9 @@ export class LayersModel extends Container implements tfc.InferenceModel {
13481351
} else {
13491352
const metric = this.metricsTensors[i][0];
13501353
const outputIndex = this.metricsTensors[i][1];
1351-
weightedMetric = tfc.mean(
1352-
metric(targets[outputIndex], outputs[outputIndex])) as Scalar;
1354+
weightedMetric =
1355+
tfc.mean(metric(targets[outputIndex], outputs[outputIndex])) as
1356+
Scalar;
13531357
}
13541358

13551359
tfc.keep(weightedMetric);
@@ -1547,7 +1551,7 @@ export class LayersModel extends Container implements tfc.InferenceModel {
15471551
* @returns A `NamedTensorMap` mapping original weight names (i.e.,
15481552
* non-uniqueified weight names) to their values.
15491553
*/
1550-
protected getNamedWeights(config?: io.SaveConfig): NamedTensor [] {
1554+
protected getNamedWeights(config?: io.SaveConfig): NamedTensor[] {
15511555
const namedWeights: NamedTensor[] = [];
15521556

15531557
const trainableOnly = config != null && config.trainableOnly;
@@ -1642,15 +1646,15 @@ export class LayersModel extends Container implements tfc.InferenceModel {
16421646
} else {
16431647
const outputNames = Object.keys(this.loss);
16441648
lossNames = {} as {[outputName: string]: LossIdentifier};
1645-
const losses = this.loss as {[outputName: string]: LossOrMetricFn|string};
1649+
const losses =
1650+
this.loss as {[outputName: string]: LossOrMetricFn | string};
16461651
for (const outputName of outputNames) {
16471652
if (typeof losses[outputName] === 'string') {
16481653
lossNames[outputName] =
16491654
toSnakeCase(losses[outputName] as string) as LossIdentifier;
16501655
} else {
16511656
throw new Error('Serialization of non-string loss is not supported.');
16521657
}
1653-
16541658
}
16551659
}
16561660
return lossNames;
@@ -1722,8 +1726,8 @@ export class LayersModel extends Container implements tfc.InferenceModel {
17221726
} else if (trainingConfig.metrics != null) {
17231727
metrics = {} as {[outputName: string]: MetricsIdentifier};
17241728
for (const key in trainingConfig.metrics) {
1725-
metrics[key] = toCamelCase(trainingConfig.metrics[key]) as
1726-
MetricsIdentifier;
1729+
metrics[key] =
1730+
toCamelCase(trainingConfig.metrics[key]) as MetricsIdentifier;
17271731
}
17281732
}
17291733

@@ -1856,9 +1860,44 @@ export class LayersModel extends Container implements tfc.InferenceModel {
18561860
[weightDataAndSpecs.data, optimizerWeightData]);
18571861
}
18581862

1863+
if (this.userDefinedMetadata != null) {
1864+
// Check serialized size of user-defined metadata.
1865+
const checkSize = true;
1866+
checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
1867+
modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
1868+
}
1869+
18591870
modelArtifacts.weightData = weightDataAndSpecs.data;
18601871
modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
18611872
return handlerOrURL.save(modelArtifacts);
18621873
}
1874+
1875+
/**
1876+
* Set user-defined metadata.
1877+
*
1878+
* The set metadata will be serialized together with the topology
1879+
* and weights of the model during `save()` calls.
1880+
*
1881+
* @param setUserDefinedMetadata
1882+
*/
1883+
setUserDefinedMetadata(userDefinedMetadata: {}): void {
1884+
checkUserDefinedMetadata(userDefinedMetadata, this.name);
1885+
this.userDefinedMetadata = userDefinedMetadata;
1886+
}
1887+
1888+
/**
1889+
* Get user-defined metadata.
1890+
*
1891+
* The metadata is supplied via one of the two routes:
1892+
* 1. By calling `setUserDefinedMetadata()`.
1893+
* 2. Loaded during model loading (if the model is constructed
1894+
* via `tf.loadLayersModel()`.)
1895+
*
1896+
* If no user-defined metadata is available from either of the
1897+
* two routes, this function will return `undefined`.
1898+
*/
1899+
getUserDefinedMetadata(): {} {
1900+
return this.userDefinedMetadata;
1901+
}
18631902
}
18641903
serialization.registerClass(LayersModel);

src/model_save_test.ts

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,7 @@ describeMathGPU('Save-load round trips', () => {
290290

291291
const getInitSpy = spyOn(initializers, 'getInitializer').and.callThrough();
292292
const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
293-
const modelPrime = await tfl.loadLayersModel(io.fromMemory(
294-
savedArtifacts.modelTopology, savedArtifacts.weightSpecs,
295-
savedArtifacts.weightData));
293+
const modelPrime = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
296294
const weightsPrime = modelPrime.getWeights();
297295
expect(weightsPrime.length).toEqual(weights.length);
298296
for (let i = 0; i < weights.length; ++i) {
@@ -321,9 +319,7 @@ describeMathGPU('Save-load round trips', () => {
321319

322320
const getInitSpy = spyOn(initializers, 'getInitializer').and.callThrough();
323321
const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
324-
const modelPrime = await tfl.loadLayersModel(io.fromMemory(
325-
savedArtifacts.modelTopology, savedArtifacts.weightSpecs,
326-
savedArtifacts.weightData));
322+
const modelPrime = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
327323
const weightsPrime = modelPrime.getWeights();
328324
expect(weightsPrime.length).toEqual(weights.length);
329325
for (let i = 0; i < weights.length; ++i) {
@@ -357,9 +353,7 @@ describeMathGPU('Save-load round trips', () => {
357353

358354
const getInitSpy = spyOn(initializers, 'getInitializer').and.callThrough();
359355
const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
360-
const modelPrime = await tfl.loadLayersModel(io.fromMemory(
361-
savedArtifacts.modelTopology, savedArtifacts.weightSpecs,
362-
savedArtifacts.weightData));
356+
const modelPrime = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
363357
const weightsPrime = modelPrime.getWeights();
364358
expect(weightsPrime.length).toEqual(weights.length);
365359
for (let i = 0; i < weights.length; ++i) {
@@ -395,9 +389,7 @@ describeMathGPU('Save-load round trips', () => {
395389

396390
const getInitSpy = spyOn(initializers, 'getInitializer').and.callThrough();
397391
const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
398-
const modelPrime = await tfl.loadLayersModel(io.fromMemory(
399-
savedArtifacts.modelTopology, savedArtifacts.weightSpecs,
400-
savedArtifacts.weightData));
392+
const modelPrime = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
401393
const weightsPrime = modelPrime.getWeights();
402394
expect(weightsPrime.length).toEqual(weights.length);
403395
for (let i = 0; i < weights.length; ++i) {
@@ -452,10 +444,7 @@ describeMathGPU('Save-load round trips', () => {
452444
const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
453445
const strict = false;
454446
const modelPrime = await tfl.loadLayersModel(
455-
io.fromMemory(
456-
savedArtifacts.modelTopology, savedArtifacts.weightSpecs,
457-
savedArtifacts.weightData),
458-
{strict});
447+
io.fromMemory(savedArtifacts), {strict});
459448
const weightsPrime = modelPrime.getWeights();
460449
expect(weightsPrime.length).toEqual(weights.length);
461450
expectTensorsClose(weightsPrime[0], weights[0]);
@@ -486,10 +475,7 @@ describeMathGPU('Save-load round trips', () => {
486475
const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
487476
const strict = false;
488477
const modelPrime = await tfl.loadLayersModel(
489-
io.fromMemory(
490-
savedArtifacts.modelTopology, savedArtifacts.weightSpecs,
491-
savedArtifacts.weightData),
492-
{strict});
478+
io.fromMemory(savedArtifacts), {strict});
493479
const weightsPrime = modelPrime.getWeights();
494480
expect(weightsPrime.length).toEqual(weights.length);
495481
expectTensorsClose(weightsPrime[0], weights[0]);

src/models.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,9 @@ export async function loadLayersModelFromIOHandler(
306306
if (trainingConfig != null) {
307307
model.loadTrainingConfig(trainingConfig);
308308
}
309+
if (artifacts.userDefinedMetadata != null) {
310+
model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
311+
}
309312

310313
// If weightData is present, load the weights into the model.
311314
if (artifacts.weightData != null) {

src/models_test.ts

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,6 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
12371237
return null;
12381238
}), {includeOptimizer: true});
12391239

1240-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
12411240
const trainingConfig = savedArtifacts.trainingConfig;
12421241
expect(trainingConfig['loss']).toEqual('mean_squared_error');
12431242

@@ -1254,8 +1253,7 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
12541253
expect(weightData.byteLength).toEqual(4 * 8 + 4 * 1 + 4);
12551254

12561255
// Load the model back, with the optimizer.
1257-
const model2 = await tfl.loadLayersModel(
1258-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1256+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
12591257
expect(model2.optimizer.getConfig()['learningRate']).toEqual(learningRate);
12601258

12611259
const optimizer1Weights = await model1.optimizer.getWeights();
@@ -1295,7 +1293,6 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
12951293
return null;
12961294
}), {includeOptimizer: true});
12971295

1298-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
12991296
const trainingConfig = savedArtifacts.trainingConfig;
13001297
expect(trainingConfig['loss']).toEqual('mean_squared_error');
13011298

@@ -1317,8 +1314,7 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
13171314
expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
13181315

13191316
// Load the model back, with the optimizer.
1320-
const model2 = await tfl.loadLayersModel(
1321-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1317+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
13221318
expect(model2.optimizer.getConfig()['learningRate']).toEqual(learningRate);
13231319
expect(model2.optimizer.getConfig()['decay']).toEqual(decay);
13241320

@@ -1358,7 +1354,6 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
13581354
return null;
13591355
}), {includeOptimizer: true});
13601356

1361-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
13621357
const trainingConfig = savedArtifacts.trainingConfig;
13631358
expect(trainingConfig['loss']).toEqual('mean_squared_error');
13641359

@@ -1380,8 +1375,7 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
13801375
expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
13811376

13821377
// Load the model back, with the optimizer.
1383-
const model2 = await tfl.loadLayersModel(
1384-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1378+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
13851379
expect(model2.optimizer.getConfig()['learningRate']).toEqual(learningRate);
13861380

13871381
const optimizer1Weights = await model1.optimizer.getWeights();
@@ -1421,7 +1415,6 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
14211415
return null;
14221416
}), {includeOptimizer: true});
14231417

1424-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
14251418
const trainingConfig = savedArtifacts.trainingConfig;
14261419
expect(trainingConfig['loss']).toEqual('mean_squared_error');
14271420

@@ -1441,8 +1434,7 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
14411434
expect(weightData.byteLength).toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2);
14421435

14431436
// Load the model back, with the optimizer.
1444-
const model2 = await tfl.loadLayersModel(
1445-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1437+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
14461438
expect(model2.optimizer.getConfig()['learningRate']).toEqual(learningRate);
14471439
expect(model2.optimizer.getConfig()['initialAccumulatorValue'])
14481440
.toEqual(initialAccumulatorValue);
@@ -1485,7 +1477,6 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
14851477
return null;
14861478
}), {includeOptimizer: true});
14871479

1488-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
14891480
const trainingConfig = savedArtifacts.trainingConfig;
14901481
expect(trainingConfig['loss']).toEqual('mean_squared_error');
14911482

@@ -1507,8 +1498,7 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
15071498
expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
15081499

15091500
// Load the model back, with the optimizer.
1510-
const model2 = await tfl.loadLayersModel(
1511-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1501+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
15121502
expect(model2.optimizer.getConfig()['learningRate']).toEqual(learningRate);
15131503
expect(model2.optimizer.getConfig()['beta1']).toEqual(beta1);
15141504
expect(model2.optimizer.getConfig()['beta2']).toEqual(beta2);
@@ -1547,7 +1537,6 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
15471537
return null;
15481538
}), {includeOptimizer: true});
15491539

1550-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
15511540
const trainingConfig = savedArtifacts.trainingConfig;
15521541
expect(trainingConfig['loss']).toEqual('mean_squared_error');
15531542

@@ -1569,8 +1558,7 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
15691558
expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3);
15701559

15711560
// Load the model back, with the optimizer.
1572-
const model2 = await tfl.loadLayersModel(
1573-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1561+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
15741562
expect(model2.optimizer.getConfig()['learningRate']).toEqual(1e-3);
15751563

15761564
const optimizer1Weights = await model1.optimizer.getWeights();
@@ -1610,7 +1598,6 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
16101598
return null;
16111599
}), {includeOptimizer: true});
16121600

1613-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
16141601
const trainingConfig = savedArtifacts.trainingConfig;
16151602
expect(trainingConfig['loss']).toEqual('mean_squared_error');
16161603

@@ -1630,8 +1617,7 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
16301617
expect(weightData.byteLength).toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2);
16311618

16321619
// Load the model back, with the optimizer.
1633-
const model2 = await tfl.loadLayersModel(
1634-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1620+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
16351621
expect(model2.optimizer.getConfig()['learningRate']).toEqual(learningRate);
16361622

16371623
const optimizer1Weights = await model1.optimizer.getWeights();
@@ -1675,15 +1661,11 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
16751661
return null;
16761662
}), {includeOptimizer: true});
16771663

1678-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
16791664
const trainingConfig = savedArtifacts.trainingConfig;
16801665
expect(trainingConfig['loss']).toEqual('categorical_crossentropy');
16811666
expect(trainingConfig['metrics']).toEqual(['acc']);
16821667

1683-
const weightSpecs = savedArtifacts.weightSpecs;
1684-
const weightData = savedArtifacts.weightData;
1685-
const model2 = await tfl.loadLayersModel(
1686-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1668+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
16871669
h = await model2.fit(xs, ys, {epochs: 1});
16881670
expect(h.history.loss.length).toEqual(1);
16891671
expect(h.history.loss[0]).toBeCloseTo(1.086648);
@@ -1711,15 +1693,11 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
17111693
return null;
17121694
}), {includeOptimizer: true});
17131695

1714-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
17151696
const trainingConfig = savedArtifacts.trainingConfig;
17161697
expect(trainingConfig['loss']).toEqual('categorical_crossentropy');
17171698
expect(trainingConfig['metrics']).toEqual(['acc']);
17181699

1719-
const weightSpecs = savedArtifacts.weightSpecs;
1720-
const weightData = savedArtifacts.weightData;
1721-
const model2 = await tfl.loadLayersModel(
1722-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1700+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
17231701

17241702
const xs = ones([4, 8]);
17251703
const ys = tensor2d([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]]);
@@ -1769,16 +1747,12 @@ describeMathCPUAndGPU('Saving+loading model with optimizer', () => {
17691747
return null;
17701748
}), {includeOptimizer: true});
17711749

1772-
const modelTopology = savedArtifacts.modelTopology as ConfigDict;
17731750
const trainingConfig = savedArtifacts.trainingConfig;
17741751
expect(trainingConfig['loss']).toEqual(
17751752
['categorical_crossentropy', 'binary_crossentropy']);
17761753
expect(trainingConfig['metrics']).toEqual(['acc']);
17771754

1778-
const weightSpecs = savedArtifacts.weightSpecs;
1779-
const weightData = savedArtifacts.weightData;
1780-
const model2 = await tfl.loadLayersModel(
1781-
io.fromMemory(modelTopology, weightSpecs, weightData, trainingConfig));
1755+
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts));
17821756

17831757
h = await model2.fit(xs, [ys1, ys2], {epochs: 1});
17841758
expect(h.history.loss.length).toEqual(1);

0 commit comments

Comments
 (0)