Skip to content

Commit 70d95be

Browse files
G4Gcopybara-github
authored andcommitted
Adds typing information to the module util.safe_ops.
PiperOrigin-RevId: 424056349
1 parent 8699795 commit 70d95be

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
tensorflow_graphics/rendering/kernels/BUILD
2-
tensorflow_graphics/rendering/opengl/BUILD
1+
tensorflow_graphics/rendering/kernels/BUILD:
2+
tensorflow_graphics/rendering/opengl/BUILD:

tensorflow_graphics/util/safe_ops.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,18 @@
2525
from __future__ import division
2626
from __future__ import print_function
2727

28+
from typing import Optional
29+
2830
import numpy as np
2931
import tensorflow as tf
3032

3133
from tensorflow_graphics.util import asserts
34+
from tensorflow_graphics.util import type_alias
3235

3336

34-
def nonzero_sign(x, name='nonzero_sign'):
37+
def nonzero_sign(
38+
x: type_alias.TensorLike,
39+
name: str = 'nonzero_sign') -> tf.Tensor:
3540
"""Returns the sign of x with sign(0) defined as 1 instead of 0."""
3641
with tf.name_scope(name):
3742
x = tf.convert_to_tensor(value=x)
@@ -40,7 +45,11 @@ def nonzero_sign(x, name='nonzero_sign'):
4045
return tf.where(tf.greater_equal(x, 0.0), one, -one)
4146

4247

43-
def safe_cospx_div_cosx(theta, factor, eps=None, name='safe_cospx_div_cosx'):
48+
def safe_cospx_div_cosx(
49+
theta: type_alias.TensorLike,
50+
factor: type_alias.TensorLike,
51+
eps: Optional[type_alias.Float] = None,
52+
name: str = 'safe_cospx_div_cosx') -> tf.Tensor:
4453
"""Calculates cos(factor * theta)/cos(theta) safely.
4554
4655
The term `cos(factor * theta)/cos(theta)` has periodic edge cases with
@@ -84,12 +93,13 @@ def safe_cospx_div_cosx(theta, factor, eps=None, name='safe_cospx_div_cosx'):
8493
return asserts.assert_no_infs_or_nans(div)
8594

8695

87-
def safe_shrink(vector,
88-
minval=None,
89-
maxval=None,
90-
open_bounds=False,
91-
eps=None,
92-
name='safe_shrink'):
96+
def safe_shrink(
97+
vector: type_alias.TensorLike,
98+
minval: Optional[type_alias.TensorLike] = None,
99+
maxval: Optional[type_alias.TensorLike] = None,
100+
open_bounds: bool = False,
101+
eps: Optional[type_alias.Float] = None,
102+
name: str = 'safe_shrink') -> tf.Tensor:
93103
"""Shrinks vector by (1.0 - eps) based on its dtype.
94104
95105
This function shrinks the input vector by a very small amount to ensure that
@@ -141,7 +151,11 @@ def safe_shrink(vector,
141151
return vector
142152

143153

144-
def safe_signed_div(a, b, eps=None, name='safe_signed_div'):
154+
def safe_signed_div(
155+
a: type_alias.TensorLike,
156+
b: type_alias.TensorLike,
157+
eps: Optional[type_alias.Float] = None,
158+
name: str = 'safe_signed_div') -> tf.Tensor:
145159
"""Calculates a/b safely.
146160
147161
If the tf-graphics debug flag is set to `True`, this function adds assertions
@@ -177,7 +191,11 @@ def safe_signed_div(a, b, eps=None, name='safe_signed_div'):
177191
return asserts.assert_no_infs_or_nans(a / (b + nonzero_sign(b) * eps))
178192

179193

180-
def safe_sinpx_div_sinx(theta, factor, eps=None, name='safe_sinpx_div_sinx'):
194+
def safe_sinpx_div_sinx(
195+
theta: type_alias.TensorLike,
196+
factor: type_alias.TensorLike,
197+
eps: Optional[type_alias.Float] = None,
198+
name: str = 'safe_sinpx_div_sinx') -> tf.Tensor:
181199
"""Calculates sin(factor * theta)/sin(theta) safely.
182200
183201
The term `sin(factor * theta)/sin(theta)` appears when calculating spherical
@@ -221,7 +239,11 @@ def safe_sinpx_div_sinx(theta, factor, eps=None, name='safe_sinpx_div_sinx'):
221239
return asserts.assert_no_infs_or_nans(div)
222240

223241

224-
def safe_unsigned_div(a, b, eps=None, name='safe_unsigned_div'):
242+
def safe_unsigned_div(
243+
a: type_alias.TensorLike,
244+
b: type_alias.TensorLike,
245+
eps: Optional[type_alias.Float] = None,
246+
name: str = 'safe_unsigned_div') -> tf.Tensor:
225247
"""Calculates a/b with b >= 0 safely.
226248
227249
If the tfg debug flag TFG_ADD_ASSERTS_TO_GRAPH defined in tfg_flags.py

0 commit comments

Comments
 (0)