@@ -91,7 +91,7 @@ class ExportArchive(BackendExportArchive):
9191
9292 **Note on resource tracking:**
9393
94- `ExportArchive` is able to automatically track all `tf .Variables` used
94+ `ExportArchive` is able to automatically track all `keras .Variables` used
9595 by its endpoints, so most of the time calling `.track(model)`
9696 is not strictly required. However, if your model uses lookup layers such
9797 as `IntegerLookup`, `StringLookup`, or `TextVectorization`,
@@ -104,9 +104,10 @@ class ExportArchive(BackendExportArchive):
104104
105105 def __init__ (self ):
106106 super ().__init__ ()
107- if backend .backend () not in ("tensorflow" , "jax" ):
107+ if backend .backend () not in ("tensorflow" , "jax" , "torch" ):
108108 raise NotImplementedError (
109- "The export API is only compatible with JAX and TF backends."
109+ "`ExportArchive` is only compatible with TensorFlow, JAX and "
110+ "Torch backends."
110111 )
111112
112113 self ._endpoint_names = []
@@ -141,8 +142,8 @@ def track(self, resource):
141142 (`TextVectorization`, `IntegerLookup`, `StringLookup`)
142143 are automatically tracked in `add_endpoint()`.
143144
144- Arguments :
145- resource: A trackable TensorFlow resource.
145+ Args :
146+ resource: A trackable Keras resource, such as a layer or model .
146147 """
147148 if isinstance (resource , layers .Layer ) and not resource .built :
148149 raise ValueError (
@@ -334,12 +335,78 @@ def serving_fn(x):
334335 self ._endpoint_names .append (name )
335336 return decorated_fn
336337
338+ def track_and_add_endpoint (self , name , resource , input_signature , ** kwargs ):
339+ """Track the variables and register a new serving endpoint.
340+
341+ This function combines the functionality of `track` and `add_endpoint`.
342+ It tracks the variables of the `resource` (either a layer or a model)
343+ and registers a serving endpoint using `resource.__call__`.
344+
345+ Args:
346+ name: `str`. The name of the endpoint.
347+ resource: A trackable Keras resource, such as a layer or model.
348+ input_signature: Optional. Specifies the shape and dtype of `fn`.
349+ Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
350+ `backend.KerasTensor`, or backend tensor (see below for an
351+ example showing a `Functional` model with 2 input arguments). If
352+ not provided, `fn` must be a `tf.function` that has been called
353+ at least once. Defaults to `None`.
354+ **kwargs: Additional keyword arguments:
355+ - Specific to the JAX backend:
356+ - `is_static`: Optional `bool`. Indicates whether `fn` is
357+ static. Set to `False` if `fn` involves state updates
358+ (e.g., RNG seeds).
359+ - `jax2tf_kwargs`: Optional `dict`. Arguments for
360+ `jax2tf.convert`. See [`jax2tf.convert`](
361+ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
362+ If `native_serialization` and `polymorphic_shapes` are
363+ not provided, they are automatically computed.
364+
365+ """
366+ if name in self ._endpoint_names :
367+ raise ValueError (f"Endpoint name '{ name } ' is already taken." )
368+ if not isinstance (resource , layers .Layer ):
369+ raise ValueError (
370+ "Invalid resource type. Expected an instance of a Keras "
371+ "`Layer` or `Model`. "
372+ f"Received: resource={ resource } (of type { type (resource )} )"
373+ )
374+ if not resource .built :
375+ raise ValueError (
376+ "The layer provided has not yet been built. "
377+ "It must be built before export."
378+ )
379+ if backend .backend () != "jax" :
380+ if "jax2tf_kwargs" in kwargs or "is_static" in kwargs :
381+ raise ValueError (
382+ "'jax2tf_kwargs' and 'is_static' are only supported with "
383+ f"the jax backend. Current backend: { backend .backend ()} "
384+ )
385+
386+ input_signature = tree .map_structure (_make_tensor_spec , input_signature )
387+
388+ if not hasattr (BackendExportArchive , "track_and_add_endpoint" ):
389+ # Default behavior.
390+ self .track (resource )
391+ return self .add_endpoint (
392+ name , resource .__call__ , input_signature , ** kwargs
393+ )
394+ else :
395+ # Special case for the torch backend.
396+ decorated_fn = BackendExportArchive .track_and_add_endpoint (
397+ self , name , resource , input_signature , ** kwargs
398+ )
399+ self ._endpoint_signatures [name ] = input_signature
400+ setattr (self ._tf_trackable , name , decorated_fn )
401+ self ._endpoint_names .append (name )
402+ return decorated_fn
403+
337404 def add_variable_collection (self , name , variables ):
338405 """Register a set of variables to be retrieved after reloading.
339406
340407 Arguments:
341408 name: The string name for the collection.
342- variables: A tuple/list/set of `tf .Variable` instances.
409+ variables: A tuple/list/set of `keras .Variable` instances.
343410
344411 Example:
345412
@@ -496,9 +563,6 @@ def export_saved_model(
496563):
497564 """Export the model as a TensorFlow SavedModel artifact for inference.
498565
499- **Note:** This feature is currently supported only with TensorFlow and
500- JAX backends.
501-
502566 This method lets you export a model to a lightweight SavedModel artifact
503567 that contains the model's forward pass only (its `call()` method)
504568 and can be served via e.g. TensorFlow Serving. The forward pass is
@@ -527,6 +591,14 @@ def export_saved_model(
527591 If `native_serialization` and `polymorphic_shapes` are not
528592 provided, they are automatically computed.
529593
594+ **Note:** This feature is currently supported only with TensorFlow, JAX and
595+ Torch backends. Support for the Torch backend is experimental.
596+
597+ **Note:** The dynamic shape feature is not yet supported with Torch
598+ backend. As a result, you must fully define the shapes of the inputs using
599+ `input_signature`. If `input_signature` is not provided, all instances of
600+ `None` (such as the batch size) will be replaced with `1`.
601+
530602 Example:
531603
532604 ```python
@@ -543,28 +615,29 @@ def export_saved_model(
543615 `export()` method relies on `ExportArchive` internally.
544616 """
545617 export_archive = ExportArchive ()
546- export_archive .track (model )
547- if isinstance (model , (Functional , Sequential )):
548- if input_signature is None :
618+ if input_signature is None :
619+ if not model .built :
620+ raise ValueError (
621+ "The layer provided has not yet been built. "
622+ "It must be built before export."
623+ )
624+ if isinstance (model , (Functional , Sequential )):
549625 input_signature = tree .map_structure (
550626 _make_tensor_spec , model .inputs
551627 )
552- if isinstance (input_signature , list ) and len (input_signature ) > 1 :
553- input_signature = [input_signature ]
554- export_archive .add_endpoint (
555- "serve" , model .__call__ , input_signature , ** kwargs
556- )
557- else :
558- if input_signature is None :
628+ if isinstance (input_signature , list ) and len (input_signature ) > 1 :
629+ input_signature = [input_signature ]
630+ else :
559631 input_signature = _get_input_signature (model )
560- if not input_signature or not model ._called :
561- raise ValueError (
562- "The model provided has never called. "
563- "It must be called at least once before export."
564- )
565- export_archive .add_endpoint (
566- "serve" , model .__call__ , input_signature , ** kwargs
567- )
632+ if not input_signature or not model ._called :
633+ raise ValueError (
634+ "The model provided has never called. "
635+ "It must be called at least once before export."
636+ )
637+
638+ export_archive .track_and_add_endpoint (
639+ "serve" , model , input_signature , ** kwargs
640+ )
568641 export_archive .write_out (filepath , verbose = verbose )
569642
570643
0 commit comments