2222import { IOHandler , ModelArtifacts , SaveResult , TrainingConfig , WeightsManifestEntry } from './types' ;
2323
2424class PassthroughLoader implements IOHandler {
25- constructor (
26- private readonly modelTopology ?: { } | ArrayBuffer ,
27- private readonly weightSpecs ?: WeightsManifestEntry [ ] ,
28- private readonly weightData ?: ArrayBuffer ,
29- private readonly trainingConfig ?: TrainingConfig ) { }
25+ constructor ( private readonly modelArtifacts ?: ModelArtifacts ) { }
3026
3127 async load ( ) : Promise < ModelArtifacts > {
32- let result = { } ;
33- if ( this . modelTopology != null ) {
34- result = { modelTopology : this . modelTopology , ...result } ;
35- }
36- if ( this . weightSpecs != null && this . weightSpecs . length > 0 ) {
37- result = { weightSpecs : this . weightSpecs , ...result } ;
38- }
39- if ( this . weightData != null && this . weightData . byteLength > 0 ) {
40- result = { weightData : this . weightData , ...result } ;
41- }
42- if ( this . trainingConfig != null ) {
43- result = { trainingConfig : this . trainingConfig , ...result } ;
44- }
45- return result ;
28+ return this . modelArtifacts ;
4629 }
4730}
4831
@@ -67,7 +50,7 @@ class PassthroughSaver implements IOHandler {
6750 * modelTopology, weightSpecs, weightData));
6851 * ```
6952 *
70- * @param modelTopology a object containing model topology (i.e., parsed from
53+ * @param modelArtifacts a object containing model topology (i.e., parsed from
7154 * the JSON format).
7255 * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
7356 * names, shapes, types, and quantization of the weight data.
@@ -78,13 +61,39 @@ class PassthroughSaver implements IOHandler {
7861 * @returns A passthrough `IOHandler` that simply loads the provided data.
7962 */
8063export function fromMemory (
81- modelTopology : { } , weightSpecs ?: WeightsManifestEntry [ ] ,
64+ modelArtifacts : { } | ModelArtifacts , weightSpecs ?: WeightsManifestEntry [ ] ,
8265 weightData ?: ArrayBuffer , trainingConfig ?: TrainingConfig ) : IOHandler {
83- // TODO(cais): The arguments should probably be consolidated into a single
84- // object, with proper deprecation process. Even though this function isn't
85- // documented, it is public and being used by some downstream libraries.
86- return new PassthroughLoader (
87- modelTopology , weightSpecs , weightData , trainingConfig ) ;
66+ if ( arguments . length === 1 ) {
67+ const isModelArtifacts =
68+ ( modelArtifacts as ModelArtifacts ) . modelTopology != null ||
69+ ( modelArtifacts as ModelArtifacts ) . weightSpecs != null ;
70+ if ( isModelArtifacts ) {
71+ return new PassthroughLoader ( modelArtifacts as ModelArtifacts ) ;
72+ } else {
73+ // Legacy support: with only modelTopology.
74+ // TODO(cais): Remove this deprecated API.
75+ console . warn (
76+ 'Please call tf.io.fromMemory() with only one argument. ' +
77+ 'The argument should be of type ModelArtifacts. ' +
78+ 'The multi-argument signature of tf.io.fromMemory() has been ' +
79+ 'deprecated and will be removed in a future release.' ) ;
80+ return new PassthroughLoader ( { modelTopology : modelArtifacts as { } } ) ;
81+ }
82+ } else {
83+ // Legacy support.
84+ // TODO(cais): Remove this deprecated API.
85+ console . warn (
86+ 'Please call tf.io.fromMemory() with only one argument. ' +
87+ 'The argument should be of type ModelArtifacts. ' +
88+ 'The multi-argument signature of tf.io.fromMemory() has been ' +
89+ 'deprecated and will be removed in a future release.' ) ;
90+ return new PassthroughLoader ( {
91+ modelTopology : modelArtifacts as { } ,
92+ weightSpecs,
93+ weightData,
94+ trainingConfig
95+ } ) ;
96+ }
8897}
8998
9099/**
0 commit comments