1818TFG_ADD_ASSERTS_TO_GRAPH is set to True.
1919"""
2020
21+ from typing import Optional
2122from absl import flags
2223import numpy as np
2324import tensorflow as tf
2425
2526from tensorflow_graphics .util import tfg_flags
27+ from tensorflow_graphics .util import type_alias
2628
2729FLAGS = flags .FLAGS
2830
2931
30- def assert_no_infs_or_nans (tensor , name = 'assert_no_infs_or_nans' ):
32+ def assert_no_infs_or_nans (
33+ tensor : type_alias .TensorLike ,
34+ name : str = 'assert_no_infs_or_nans' ) -> tf .Tensor :
3135 """Checks a tensor for NaN and Inf values.
3236
3337 Note:
@@ -56,7 +60,11 @@ def assert_no_infs_or_nans(tensor, name='assert_no_infs_or_nans'):
5660 return tf .identity (tensor )
5761
5862
59- def assert_all_above (vector , minval , open_bound = False , name = 'assert_all_above' ):
63+ def assert_all_above (
64+ vector : type_alias .TensorLike ,
65+ minval : type_alias .TensorLike ,
66+ open_bound : bool = False ,
67+ name : str = 'assert_all_above' ) -> tf .Tensor :
6068 """Checks whether all values of vector are above minval.
6169
6270 Note:
@@ -91,7 +99,11 @@ def assert_all_above(vector, minval, open_bound=False, name='assert_all_above'):
9199 return tf .identity (vector )
92100
93101
94- def assert_all_below (vector , maxval , open_bound = False , name = 'assert_all_below' ):
102+ def assert_all_below (
103+ vector : type_alias .TensorLike ,
104+ maxval : type_alias .TensorLike ,
105+ open_bound : bool = False ,
106+ name : str = 'assert_all_below' ) -> tf .Tensor :
95107 """Checks whether all values of vector are below maxval.
96108
97109 Note:
@@ -126,11 +138,12 @@ def assert_all_below(vector, maxval, open_bound=False, name='assert_all_below'):
126138 return tf .identity (vector )
127139
128140
129- def assert_all_in_range (vector ,
130- minval ,
131- maxval ,
132- open_bounds = False ,
133- name = 'assert_all_in_range' ):
141+ def assert_all_in_range (
142+ vector : type_alias .TensorLike ,
143+ minval : type_alias .TensorLike ,
144+ maxval : type_alias .TensorLike ,
145+ open_bounds : bool = False ,
146+ name : str = 'assert_all_in_range' ) -> tf .Tensor :
134147 """Checks whether all values of vector are between minval and maxval.
135148
136149 This function checks if all the values in the given vector are in an interval
@@ -174,7 +187,10 @@ def assert_all_in_range(vector,
174187 return tf .identity (vector )
175188
176189
177- def assert_nonzero_norm (vector , eps = None , name = 'assert_nonzero_norm' ):
190+ def assert_nonzero_norm (
191+ vector : type_alias .TensorLike ,
192+ eps : Optional [type_alias .Float ] = None ,
193+ name : str = 'assert_nonzero_norm' ) -> tf .Tensor :
178194 """Checks whether vector/quaternion has non-zero norm in its last dimension.
179195
180196 This function checks whether all the norms of the vectors are greater than
@@ -213,11 +229,12 @@ def assert_nonzero_norm(vector, eps=None, name='assert_nonzero_norm'):
213229 return tf .identity (vector )
214230
215231
216- def assert_normalized (vector ,
217- order = 'euclidean' ,
218- axis = - 1 ,
219- eps = None ,
220- name = 'assert_normalized' ):
232+ def assert_normalized (
233+ vector : type_alias .TensorLike ,
234+ order : str = 'euclidean' ,
235+ axis : int = - 1 ,
236+ eps : Optional [type_alias .Float ] = None ,
237+ name : str = 'assert_normalized' ) -> tf .Tensor :
221238 """Checks whether vector/quaternion is normalized in its last dimension.
222239
223240 Note:
@@ -254,10 +271,10 @@ def assert_normalized(vector,
254271 return tf .identity (vector )
255272
256273
257- def assert_at_least_k_non_zero_entries (tensor ,
258- k = 1 ,
259- name = 'assert_at_least_k_non_zero_entries'
260- ) :
274+ def assert_at_least_k_non_zero_entries (
275+ tensor : type_alias . TensorLike ,
276+ k : int = 1 ,
277+ name : str = 'assert_at_least_k_non_zero_entries' ) -> tf . Tensor :
261278 """Checks if `tensor` has at least k non-zero entries in the last dimension.
262279
263280 Given a tensor with `M` dimensions in its last axis, this function checks
@@ -292,7 +309,9 @@ def assert_at_least_k_non_zero_entries(tensor,
292309 return tf .identity (tensor )
293310
294311
295- def assert_binary (tensor , name = 'assert_binary' ):
312+ def assert_binary (
313+ tensor : type_alias .TensorLike ,
314+ name : str = 'assert_binary' ) -> tf .Tensor :
296315 """Asserts that all the values in the tensor are zeros or ones.
297316
298317 Args:
@@ -320,7 +339,7 @@ def assert_binary(tensor, name='assert_binary'):
320339 return tf .identity (tensor )
321340
322341
323- def select_eps_for_addition (dtype ) :
342+ def select_eps_for_addition (dtype : tf . DType ) -> type_alias . Float :
324343 """Returns 2 * machine epsilon based on `dtype`.
325344
326345 This function picks an epsilon slightly greater than the machine epsilon,
@@ -339,7 +358,7 @@ def select_eps_for_addition(dtype):
339358 return 2.0 * np .finfo (dtype .as_numpy_dtype ).eps
340359
341360
342- def select_eps_for_division (dtype ) :
361+ def select_eps_for_division (dtype : tf . DType ) -> type_alias . Float :
343362 """Selects default values for epsilon to make divisions safe based on dtype.
344363
345364 This function returns an epsilon slightly greater than the smallest positive
0 commit comments