 from __future__ import print_function
 
 import itertools
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import six
 import tensorflow as tf
 
 
-def _broadcast_shape_helper(shape_x, shape_y):
+def _broadcast_shape_helper(shape_x: tf.TensorShape,
+                            shape_y: tf.TensorShape) -> Optional[List[Any]]:
   """Helper function for is_broadcast_compatible and broadcast_shape.
 
   Args:
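
For context: the helper applies NumPy-style broadcasting to static `tf.TensorShape`s, matching dimensions from the right. A rough sketch of the intended behavior (illustrative calls, not part of this commit):

    import tensorflow as tf

    # A dimension of 1 (or an unknown dimension) is compatible with anything;
    # the helper returns the merged dimensions, or None when the shapes
    # cannot be broadcast together.
    _broadcast_shape_helper(tf.TensorShape([3, 1]), tf.TensorShape([5]))
    # -> dimensions equivalent to [3, 5]
    _broadcast_shape_helper(tf.TensorShape([3]), tf.TensorShape([4]))
    # -> None
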
@@ -74,7 +76,8 @@ def _broadcast_shape_helper(shape_x, shape_y):
   return return_dims
 
 
-def is_broadcast_compatible(shape_x, shape_y):
+def is_broadcast_compatible(shape_x: tf.TensorShape,
+                            shape_y: tf.TensorShape) -> bool:
   """Returns True if `shape_x` and `shape_y` are broadcast compatible.
 
   Args:
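
A quick usage sketch, assuming this module is `tensorflow_graphics.util.shape`:

    import tensorflow as tf
    from tensorflow_graphics.util import shape

    # [8, 1, 3] and [1, 5, 3] broadcast to [8, 5, 3].
    assert shape.is_broadcast_compatible(tf.TensorShape([8, 1, 3]),
                                         tf.TensorShape([1, 5, 3]))
    # The trailing dimensions 3 and 2 cannot be reconciled.
    assert not shape.is_broadcast_compatible(tf.TensorShape([2, 3]),
                                             tf.TensorShape([3, 2]))
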
@@ -90,7 +93,8 @@ def is_broadcast_compatible(shape_x, shape_y):
   return _broadcast_shape_helper(shape_x, shape_y) is not None
 
 
-def get_broadcasted_shape(shape_x, shape_y):
+def get_broadcasted_shape(shape_x: tf.TensorShape,
+                          shape_y: tf.TensorShape) -> Optional[List[Any]]:
   """Returns the common shape for broadcast compatible shapes.
 
   Args:
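
The companion accessor returns the merged shape instead of a boolean; continuing the sketch above:

    merged = shape.get_broadcasted_shape(tf.TensorShape([8, 1, 3]),
                                         tf.TensorShape([1, 5, 3]))
    # merged describes [8, 5, 3]; incompatible shapes yield None instead.
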
@@ -135,14 +139,15 @@ def _get_dim(tensor, axis):
   return tf.compat.dimension_value(tensor.shape[axis])
 
 
-def check_static(tensor,
-                 has_rank=None,
-                 has_rank_greater_than=None,
-                 has_rank_less_than=None,
+def check_static(tensor: tf.Tensor,
+                 has_rank: Optional[int] = None,
+                 has_rank_greater_than: Optional[int] = None,
+                 has_rank_less_than: Optional[int] = None,
                  has_dim_equals=None,
                  has_dim_greater_than=None,
                  has_dim_less_than=None,
-                 tensor_name='tensor'):
+                 tensor_name: str = 'tensor') -> None:
+  # TODO(cengizo): Typing for has_dim_equals, has_dim_greater(less)_than.
   """Checks static shapes for rank and dimension constraints.
 
   This function can be used to check a tensor's shape for multiple rank and
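
A sketch of how the newly typed signature is exercised. The (axis, dimension) tuple form for `has_dim_equals` is an assumption based on the pending TODO, not something this hunk shows:

    x = tf.ones((2, 3, 4))
    # Passes: rank is exactly 3 and the last dimension is 4.
    shape.check_static(x, has_rank=3, has_dim_equals=(-1, 4), tensor_name='x')
    # Raises ValueError: the rank of x is 3, which is not greater than 3.
    shape.check_static(x, has_rank_greater_than=3, tensor_name='x')
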
@@ -276,11 +281,12 @@ def _raise_error(tensor_names, batch_shapes):
       'Not all batch dimensions are identical: {}'.format(formatted_list))
 
 
-def compare_batch_dimensions(tensors,
-                             last_axes,
-                             broadcast_compatible,
-                             initial_axes=0,
-                             tensor_names=None):
+def compare_batch_dimensions(
+    tensors: Union[List[tf.Tensor], Tuple[tf.Tensor]],
+    last_axes: Union[int, List[int], Tuple[int]],
+    broadcast_compatible: bool,
+    initial_axes: Union[int, List[int], Tuple[int]] = 0,
+    tensor_names: Optional[Union[List[str], Tuple[str]]] = None) -> None:
   """Compares batch dimensions for tensors with static shapes.
 
   Args:
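
An illustrative call, assuming `last_axes` marks the last batch axis inclusively (per the function's docstring, which this hunk does not show):

    a = tf.ones((5, 4, 3))
    b = tf.ones((5, 4, 6))
    # Axes 0..1 are treated as batch dimensions and must match exactly.
    shape.compare_batch_dimensions((a, b), last_axes=1,
                                   broadcast_compatible=False)
    # With broadcast_compatible=True, batch shapes such as (5, 1) and (1, 4)
    # would also be accepted.
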
@@ -347,7 +353,10 @@ def compare_batch_dimensions(tensors,
         ]))
 
 
-def compare_dimensions(tensors, axes, tensor_names=None):
+def compare_dimensions(
+    tensors: Union[List[tf.Tensor], Tuple[tf.Tensor]],
+    axes: Union[int, List[int], Tuple[int]],
+    tensor_names: Optional[Union[List[str], Tuple[str]]] = None) -> None:
   """Compares dimensions of tensors with static or dynamic shapes.
 
   Args:
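
For the per-axis comparison, a minimal sketch (tensor names are illustrative):

    a = tf.ones((4, 3))
    b = tf.ones((3, 5))
    # Dimension 1 of `a` must equal dimension 0 of `b`, as in a matmul.
    shape.compare_dimensions((a, b), axes=(1, 0), tensor_names=('a', 'b'))
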
@@ -376,15 +385,19 @@ def compare_dimensions(tensors, axes, tensor_names=None):
         list(tensor_names), list(axes), list(dimensions)))
 
 
-def is_static(tensor_shape):
+def is_static(
+    tensor_shape: Union[List[Any], Tuple[Any], tf.TensorShape]) -> bool:
   """Checks if the given tensor shape is static."""
   if isinstance(tensor_shape, (list, tuple)):
     return None not in tensor_shape
   else:
     return None not in tensor_shape.as_list()
 
 
-def add_batch_dimensions(tensor, tensor_name, batch_shape, last_axis=None):
+def add_batch_dimensions(tensor: tf.Tensor,
+                         tensor_name: str,
+                         batch_shape: List[int],
+                         last_axis: Optional[int] = None) -> tf.Tensor:
   """Broadcasts tensor to match batch dimensions.
 
   It will either broadcast to all provided batch dimensions, therefore