2525from __future__ import division
2626from __future__ import print_function
2727
28+ from typing import Optional
29+
2830import numpy as np
2931import tensorflow as tf
3032
3133from tensorflow_graphics .util import asserts
34+ from tensorflow_graphics .util import type_alias
3235
3336
34- def nonzero_sign (x , name = 'nonzero_sign' ):
37+ def nonzero_sign (
38+ x : type_alias .TensorLike ,
39+ name : str = 'nonzero_sign' ) -> tf .Tensor :
3540 """Returns the sign of x with sign(0) defined as 1 instead of 0."""
3641 with tf .name_scope (name ):
3742 x = tf .convert_to_tensor (value = x )
@@ -40,7 +45,11 @@ def nonzero_sign(x, name='nonzero_sign'):
4045 return tf .where (tf .greater_equal (x , 0.0 ), one , - one )
4146
4247
43- def safe_cospx_div_cosx (theta , factor , eps = None , name = 'safe_cospx_div_cosx' ):
48+ def safe_cospx_div_cosx (
49+ theta : type_alias .TensorLike ,
50+ factor : type_alias .TensorLike ,
51+ eps : Optional [type_alias .Float ] = None ,
52+ name : str = 'safe_cospx_div_cosx' ) -> tf .Tensor :
4453 """Calculates cos(factor * theta)/cos(theta) safely.
4554
4655 The term `cos(factor * theta)/cos(theta)` has periodic edge cases with
@@ -84,12 +93,13 @@ def safe_cospx_div_cosx(theta, factor, eps=None, name='safe_cospx_div_cosx'):
8493 return asserts .assert_no_infs_or_nans (div )
8594
8695
87- def safe_shrink (vector ,
88- minval = None ,
89- maxval = None ,
90- open_bounds = False ,
91- eps = None ,
92- name = 'safe_shrink' ):
96+ def safe_shrink (
97+ vector : type_alias .TensorLike ,
98+ minval : Optional [type_alias .TensorLike ] = None ,
99+ maxval : Optional [type_alias .TensorLike ] = None ,
100+ open_bounds : bool = False ,
101+ eps : Optional [type_alias .Float ] = None ,
102+ name : str = 'safe_shrink' ) -> tf .Tensor :
93103 """Shrinks vector by (1.0 - eps) based on its dtype.
94104
95105 This function shrinks the input vector by a very small amount to ensure that
@@ -141,7 +151,11 @@ def safe_shrink(vector,
141151 return vector
142152
143153
144- def safe_signed_div (a , b , eps = None , name = 'safe_signed_div' ):
154+ def safe_signed_div (
155+ a : type_alias .TensorLike ,
156+ b : type_alias .TensorLike ,
157+ eps : Optional [type_alias .Float ] = None ,
158+ name : str = 'safe_signed_div' ) -> tf .Tensor :
145159 """Calculates a/b safely.
146160
147161 If the tf-graphics debug flag is set to `True`, this function adds assertions
@@ -177,7 +191,11 @@ def safe_signed_div(a, b, eps=None, name='safe_signed_div'):
177191 return asserts .assert_no_infs_or_nans (a / (b + nonzero_sign (b ) * eps ))
178192
179193
180- def safe_sinpx_div_sinx (theta , factor , eps = None , name = 'safe_sinpx_div_sinx' ):
194+ def safe_sinpx_div_sinx (
195+ theta : type_alias .TensorLike ,
196+ factor : type_alias .TensorLike ,
197+ eps : Optional [type_alias .Float ] = None ,
198+ name : str = 'safe_sinpx_div_sinx' ) -> tf .Tensor :
181199 """Calculates sin(factor * theta)/sin(theta) safely.
182200
183201 The term `sin(factor * theta)/sin(theta)` appears when calculating spherical
@@ -221,7 +239,11 @@ def safe_sinpx_div_sinx(theta, factor, eps=None, name='safe_sinpx_div_sinx'):
221239 return asserts .assert_no_infs_or_nans (div )
222240
223241
224- def safe_unsigned_div (a , b , eps = None , name = 'safe_unsigned_div' ):
242+ def safe_unsigned_div (
243+ a : type_alias .TensorLike ,
244+ b : type_alias .TensorLike ,
245+ eps : Optional [type_alias .Float ] = None ,
246+ name : str = 'safe_unsigned_div' ) -> tf .Tensor :
225247 """Calculates a/b with b >= 0 safely.
226248
227249 If the tfg debug flag TFG_ADD_ASSERTS_TO_GRAPH defined in tfg_flags.py
0 commit comments