@@ -1555,15 +1555,32 @@ def call(self, inputs, inputs_positions=None):
 class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
   """TransformerEncoderBlock layer with stochastic depth and layerscale."""
 
-  def __init__(self,
-               *args,
-               stochastic_depth_drop_rate=0.0,
-               layer_scale_init_value=0.0,
-               **kwargs):
-    """Initializes TransformerEncoderBlock."""
+  def __init__(
+      self,
+      *args,
+      stochastic_depth_drop_rate=0.0,
+      layer_scale_init_value=0.0,
+      max_attention_inference_parallelism=None,
+      **kwargs
+  ):
+    """Initializes TransformerEncoderBlock.
+
+    Args:
+      *args: positional arguments passed to super().__init__.
+      stochastic_depth_drop_rate: the drop rate for the stochastic depth layer.
+      layer_scale_init_value: the initial value for the layer scale; 0.0 disables layer scale.
+      max_attention_inference_parallelism: the number of examples to run in
+        parallel in the attention blocks during inference. Set this limit to
+        reduce the peak memory usage. If None, use vectorized operations to run
+        the whole batch in parallel.
+      **kwargs: keyword arguments passed to super().__init__.
+    """
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
     self._layer_scale_init_value = layer_scale_init_value
+    self._max_attention_inference_parallelism = (
+        max_attention_inference_parallelism
+    )
 
   def build(self, input_shape):
     if self._stochastic_depth_drop_rate:
@@ -1582,10 +1599,25 @@ def build(self, input_shape):
       self._layer_scale_mlp = lambda x, *args, **kwargs: tf.identity(x)
     super().build(input_shape)
 
+    if self._max_attention_inference_parallelism is not None:
+      attention_layer_config = self._attention_layer.get_config()
+      self._attention_layer = nn_layers.MultiHeadAttention.from_config({
+          **attention_layer_config,
+          'max_inference_parallelism': (
+              self._max_attention_inference_parallelism
+          ),
+      })
+
   def get_config(self):
-    config = {'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate}
-    base_config = super().get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+    config = super().get_config()
+    config.update({
+        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
+        'layer_scale_init_value': self._layer_scale_init_value,
+        'max_attention_inference_parallelism': (
+            self._max_attention_inference_parallelism
+        ),
+    })
+    return config
 
   def call(self, inputs, output_range=None, training=None):
     """Transformer self-attention encoder block call."""
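
As a quick sanity check on the new argument, here is a minimal usage sketch (not part of the diff). It assumes the parent NLP TransformerEncoderBlock's usual num_attention_heads / inner_dim / inner_activation constructor arguments; the concrete sizes and the parallelism cap of 8 are illustrative only.

# Illustrative construction of the Vision TransformerEncoderBlock with the
# new inference-parallelism cap. All numeric values are made up.
block = TransformerEncoderBlock(
    num_attention_heads=4,
    inner_dim=256,
    inner_activation='gelu',
    stochastic_depth_drop_rate=0.1,
    layer_scale_init_value=1e-5,
    max_attention_inference_parallelism=8,
)
# With the updated get_config(), the new fields round-trip through the
# standard Keras serialization path.
restored = TransformerEncoderBlock.from_config(block.get_config())

The build() change above simply re-creates nn_layers.MultiHeadAttention from its own config with the 'max_inference_parallelism' field overridden, so no other attention settings change.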
@@ -1675,29 +1707,39 @@ def call(self, inputs, output_range=None, training=None):
 
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class TransformerScaffold(nlp_modeling.layers.TransformerScaffold):
-  """TransformerScaffold layer for vision applications.
-
-  This layer is a subclass of NLP TransformerScaffold:
+  """TransformerScaffold layer for vision applications."""
 
-  Attributes:
-    stochastic_depth_drop_rate: Drop rate for the residual connections.
-    return_attention_scores: Optionally return the attention output.
-    ffn_has_residual_connection: Whether the feedforward network has internal
-      residual connection and layer norm. If False, the residual connection and
-      the layer norm op are called inside TransformerScaffold.
-  """
+  def __init__(
+      self,
+      *args,
+      stochastic_depth_drop_rate: float = 0.0,
+      return_attention_scores: bool = False,
+      ffn_has_residual_connection: bool = False,
+      max_attention_inference_parallelism: Optional[int] = None,
+      **kwargs
+  ):
+    """Initializes TransformerScaffold.
 
-  def __init__(self,
-               *args,
-               stochastic_depth_drop_rate: float = 0.0,
-               return_attention_scores: bool = False,
-               ffn_has_residual_connection: bool = False,
-               **kwargs):
-    """Initializes TransformerEncoderBlock."""
+    Args:
+      *args: positional arguments passed to super().__init__.
+      stochastic_depth_drop_rate: the drop rate for the stochastic depth layer.
+      return_attention_scores: whether to return the attention output.
+      ffn_has_residual_connection: whether the feedforward network has internal
+        residual connection and layer norm. If False, the residual connection
+        and the layer norm op are called inside TransformerScaffold.
+      max_attention_inference_parallelism: the number of examples to run in
+        parallel in the attention blocks during inference. Set this limit to
+        reduce the peak memory usage. If None, use vectorized operations to run
+        the whole batch in parallel.
+      **kwargs: keyword arguments passed to super().__init__.
+    """
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
     self._return_attention_scores = return_attention_scores
     self._ffn_has_residual_connection = ffn_has_residual_connection
+    self._max_attention_inference_parallelism = (
+        max_attention_inference_parallelism
+    )
 
   def build(self, input_shape: Union[tf.TensorShape, List[int]]):
     if self._stochastic_depth_drop_rate:
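
A comparable sketch for the scaffold variant, again outside the diff. The attention_cls keyword and nn_layers.MultiHeadAttention follow the parent NLP TransformerScaffold's interface; treat the exact constructor arguments and values below as assumptions.

# Illustrative construction of the Vision TransformerScaffold; the capped
# attention layer is swapped in later by build(), not here.
scaffold = TransformerScaffold(
    num_attention_heads=4,
    inner_dim=256,
    inner_activation='gelu',
    attention_cls=nn_layers.MultiHeadAttention,
    stochastic_depth_drop_rate=0.1,
    max_attention_inference_parallelism=8,
)
# build() re-instantiates the attention layer from its config with
# 'max_inference_parallelism' overridden, so the cap takes effect during
# inference as described in the docstring above.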
@@ -1708,15 +1750,26 @@ def build(self, input_shape: Union[tf.TensorShape, List[int]]):
 
     super().build(input_shape)
 
+    if self._max_attention_inference_parallelism is not None:
+      attention_layer_config = self._attention_layer.get_config()
+      self._attention_layer = self._attention_cls.from_config({
+          **attention_layer_config,
+          'max_inference_parallelism': (
+              self._max_attention_inference_parallelism
+          ),
+      })
+
   def get_config(self):
-    config = {
+    config = super().get_config()
+    config.update({
         'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
         'return_attention_scores': self._return_attention_scores,
-        'ffn_has_residual_connection': self._ffn_has_residual_connection
-    }
-    base_config = super().get_config()
-    base_config.update(config)
-    return base_config
+        'ffn_has_residual_connection': self._ffn_has_residual_connection,
+        'max_attention_inference_parallelism': (
+            self._max_attention_inference_parallelism
+        ),
+    })
+    return config
 
   def call(
       self,