Skip to content

Commit 8538149

Browse files
G4Gcopybara-github
authored andcommitted
Adds typing information to the module util.shape.
PiperOrigin-RevId: 425878021
1 parent 70d95be commit 8538149

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

tensorflow_graphics/util/shape.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import print_function
1919

2020
import itertools
21+
from typing import Any, List, Optional, Tuple, Union
2122

2223
import numpy as np
2324
import six
@@ -26,7 +27,8 @@
2627
import tensorflow as tf
2728

2829

29-
def _broadcast_shape_helper(shape_x, shape_y):
30+
def _broadcast_shape_helper(shape_x: tf.TensorShape,
31+
shape_y: tf.TensorShape) -> Optional[List[Any]]:
3032
"""Helper function for is_broadcast_compatible and broadcast_shape.
3133
3234
Args:
@@ -74,7 +76,8 @@ def _broadcast_shape_helper(shape_x, shape_y):
7476
return return_dims
7577

7678

77-
def is_broadcast_compatible(shape_x, shape_y):
79+
def is_broadcast_compatible(shape_x: tf.TensorShape,
80+
shape_y: tf.TensorShape) -> bool:
7881
"""Returns True if `shape_x` and `shape_y` are broadcast compatible.
7982
8083
Args:
@@ -90,7 +93,8 @@ def is_broadcast_compatible(shape_x, shape_y):
9093
return _broadcast_shape_helper(shape_x, shape_y) is not None
9194

9295

93-
def get_broadcasted_shape(shape_x, shape_y):
96+
def get_broadcasted_shape(shape_x: tf.TensorShape,
97+
shape_y: tf.TensorShape) -> Optional[List[Any]]:
9498
"""Returns the common shape for broadcast compatible shapes.
9599
96100
Args:
@@ -135,14 +139,15 @@ def _get_dim(tensor, axis):
135139
return tf.compat.dimension_value(tensor.shape[axis])
136140

137141

138-
def check_static(tensor,
139-
has_rank=None,
140-
has_rank_greater_than=None,
141-
has_rank_less_than=None,
142+
def check_static(tensor: tf.Tensor,
143+
has_rank: Optional[int] = None,
144+
has_rank_greater_than: Optional[int] = None,
145+
has_rank_less_than: Optional[int] = None,
142146
has_dim_equals=None,
143147
has_dim_greater_than=None,
144148
has_dim_less_than=None,
145-
tensor_name='tensor'):
149+
tensor_name: str = 'tensor') -> None:
150+
# TODO(cengizo): Typing for has_dim_equals, has_dim_greater(less)_than.
146151
"""Checks static shapes for rank and dimension constraints.
147152
148153
This function can be used to check a tensor's shape for multiple rank and
@@ -276,11 +281,12 @@ def _raise_error(tensor_names, batch_shapes):
276281
'Not all batch dimensions are identical: {}'.format(formatted_list))
277282

278283

279-
def compare_batch_dimensions(tensors,
280-
last_axes,
281-
broadcast_compatible,
282-
initial_axes=0,
283-
tensor_names=None):
284+
def compare_batch_dimensions(
285+
tensors: Union[List[tf.Tensor], Tuple[tf.Tensor]],
286+
last_axes: Union[int, List[int], Tuple[int]],
287+
broadcast_compatible: bool,
288+
initial_axes: Union[int, List[int], Tuple[int]] = 0,
289+
tensor_names: Optional[Union[List[str], Tuple[str]]] = None) -> None:
284290
"""Compares batch dimensions for tensors with static shapes.
285291
286292
Args:
@@ -347,7 +353,10 @@ def compare_batch_dimensions(tensors,
347353
]))
348354

349355

350-
def compare_dimensions(tensors, axes, tensor_names=None):
356+
def compare_dimensions(
357+
tensors: Union[List[tf.Tensor], Tuple[tf.Tensor]],
358+
axes: Union[int, List[int], Tuple[int]],
359+
tensor_names: Optional[Union[List[str], Tuple[str]]] = None) -> None:
351360
"""Compares dimensions of tensors with static or dynamic shapes.
352361
353362
Args:
@@ -376,15 +385,19 @@ def compare_dimensions(tensors, axes, tensor_names=None):
376385
list(tensor_names), list(axes), list(dimensions)))
377386

378387

379-
def is_static(tensor_shape):
388+
def is_static(
389+
tensor_shape: Union[List[Any], Tuple[Any], tf.TensorShape]) -> bool:
380390
"""Checks if the given tensor shape is static."""
381391
if isinstance(tensor_shape, (list, tuple)):
382392
return None not in tensor_shape
383393
else:
384394
return None not in tensor_shape.as_list()
385395

386396

387-
def add_batch_dimensions(tensor, tensor_name, batch_shape, last_axis=None):
397+
def add_batch_dimensions(tensor: tf.Tensor,
398+
tensor_name: str,
399+
batch_shape: List[int],
400+
last_axis: Optional[int] = None) -> tf.Tensor:
388401
"""Broadcasts tensor to match batch dimensions.
389402
390403
It will either broadcast to all provided batch dimensions, therefore

0 commit comments

Comments
 (0)