Skip to content

Commit 8699795

Browse files
G4Gcopybara-github
authored andcommitted
Adds typing information to the module util.asserts.
PiperOrigin-RevId: 420148339
1 parent 5a9cb4f commit 8699795

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

tensorflow_graphics/util/asserts.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,20 @@
1818
TFG_ADD_ASSERTS_TO_GRAPH is set to True.
1919
"""
2020

21+
from typing import Optional
2122
from absl import flags
2223
import numpy as np
2324
import tensorflow as tf
2425

2526
from tensorflow_graphics.util import tfg_flags
27+
from tensorflow_graphics.util import type_alias
2628

2729
FLAGS = flags.FLAGS
2830

2931

30-
def assert_no_infs_or_nans(tensor, name='assert_no_infs_or_nans'):
32+
def assert_no_infs_or_nans(
33+
tensor: type_alias.TensorLike,
34+
name: str = 'assert_no_infs_or_nans') -> tf.Tensor:
3135
"""Checks a tensor for NaN and Inf values.
3236
3337
Note:
@@ -56,7 +60,11 @@ def assert_no_infs_or_nans(tensor, name='assert_no_infs_or_nans'):
5660
return tf.identity(tensor)
5761

5862

59-
def assert_all_above(vector, minval, open_bound=False, name='assert_all_above'):
63+
def assert_all_above(
64+
vector: type_alias.TensorLike,
65+
minval: type_alias.TensorLike,
66+
open_bound: bool = False,
67+
name: str = 'assert_all_above') -> tf.Tensor:
6068
"""Checks whether all values of vector are above minval.
6169
6270
Note:
@@ -91,7 +99,11 @@ def assert_all_above(vector, minval, open_bound=False, name='assert_all_above'):
9199
return tf.identity(vector)
92100

93101

94-
def assert_all_below(vector, maxval, open_bound=False, name='assert_all_below'):
102+
def assert_all_below(
103+
vector: type_alias.TensorLike,
104+
maxval: type_alias.TensorLike,
105+
open_bound: bool = False,
106+
name: str = 'assert_all_below') -> tf.Tensor:
95107
"""Checks whether all values of vector are below maxval.
96108
97109
Note:
@@ -126,11 +138,12 @@ def assert_all_below(vector, maxval, open_bound=False, name='assert_all_below'):
126138
return tf.identity(vector)
127139

128140

129-
def assert_all_in_range(vector,
130-
minval,
131-
maxval,
132-
open_bounds=False,
133-
name='assert_all_in_range'):
141+
def assert_all_in_range(
142+
vector: type_alias.TensorLike,
143+
minval: type_alias.TensorLike,
144+
maxval: type_alias.TensorLike,
145+
open_bounds: bool = False,
146+
name: str = 'assert_all_in_range') -> tf.Tensor:
134147
"""Checks whether all values of vector are between minval and maxval.
135148
136149
This function checks if all the values in the given vector are in an interval
@@ -174,7 +187,10 @@ def assert_all_in_range(vector,
174187
return tf.identity(vector)
175188

176189

177-
def assert_nonzero_norm(vector, eps=None, name='assert_nonzero_norm'):
190+
def assert_nonzero_norm(
191+
vector: type_alias.TensorLike,
192+
eps: Optional[type_alias.Float] = None,
193+
name: str = 'assert_nonzero_norm') -> tf.Tensor:
178194
"""Checks whether vector/quaternion has non-zero norm in its last dimension.
179195
180196
This function checks whether all the norms of the vectors are greater than
@@ -213,11 +229,12 @@ def assert_nonzero_norm(vector, eps=None, name='assert_nonzero_norm'):
213229
return tf.identity(vector)
214230

215231

216-
def assert_normalized(vector,
217-
order='euclidean',
218-
axis=-1,
219-
eps=None,
220-
name='assert_normalized'):
232+
def assert_normalized(
233+
vector: type_alias.TensorLike,
234+
order: str = 'euclidean',
235+
axis: int = -1,
236+
eps: Optional[type_alias.Float] = None,
237+
name: str = 'assert_normalized') -> tf.Tensor:
221238
"""Checks whether vector/quaternion is normalized in its last dimension.
222239
223240
Note:
@@ -254,10 +271,10 @@ def assert_normalized(vector,
254271
return tf.identity(vector)
255272

256273

257-
def assert_at_least_k_non_zero_entries(tensor,
258-
k=1,
259-
name='assert_at_least_k_non_zero_entries'
260-
):
274+
def assert_at_least_k_non_zero_entries(
275+
tensor: type_alias.TensorLike,
276+
k: int = 1,
277+
name: str = 'assert_at_least_k_non_zero_entries') -> tf.Tensor:
261278
"""Checks if `tensor` has at least k non-zero entries in the last dimension.
262279
263280
Given a tensor with `M` dimensions in its last axis, this function checks
@@ -292,7 +309,9 @@ def assert_at_least_k_non_zero_entries(tensor,
292309
return tf.identity(tensor)
293310

294311

295-
def assert_binary(tensor, name='assert_binary'):
312+
def assert_binary(
313+
tensor: type_alias.TensorLike,
314+
name: str = 'assert_binary') -> tf.Tensor:
296315
"""Asserts that all the values in the tensor are zeros or ones.
297316
298317
Args:
@@ -320,7 +339,7 @@ def assert_binary(tensor, name='assert_binary'):
320339
return tf.identity(tensor)
321340

322341

323-
def select_eps_for_addition(dtype):
342+
def select_eps_for_addition(dtype: tf.DType) -> type_alias.Float:
324343
"""Returns 2 * machine epsilon based on `dtype`.
325344
326345
This function picks an epsilon slightly greater than the machine epsilon,
@@ -339,7 +358,7 @@ def select_eps_for_addition(dtype):
339358
return 2.0 * np.finfo(dtype.as_numpy_dtype).eps
340359

341360

342-
def select_eps_for_division(dtype):
361+
def select_eps_for_division(dtype: tf.DType) -> type_alias.Float:
343362
"""Selects default values for epsilon to make divisions safe based on dtype.
344363
345364
This function returns an epsilon slightly greater than the smallest positive

0 commit comments

Comments
 (0)