@@ -125,57 +125,71 @@ defmodule EXLA.MLIR.Value do
125125 end
126126 end
127127
128- def is_infinity ( % Value { function: func } = operand , typespec ) do
128+ def is_infinity ( % Value { function: func } = operand , out_typespec ) do
129129 % { type: type } = get_typespec ( operand )
130130
131- typespec = Typespec . to_type ( typespec , { :pred , 8 } )
131+ typespec = Typespec . to_type ( out_typespec , { :pred , 8 } )
132132
133- cond do
134- Nx.Type . complex? ( type ) ->
135- float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
136- real = real ( operand , float_typespec )
137- imag = imag ( operand , float_typespec )
138- is_inf_real = is_infinity ( real , typespec )
139- is_inf_imag = is_infinity ( imag , typespec )
140- bitwise_or ( is_inf_real , is_inf_imag , typespec )
141-
142- Nx.Type . integer? ( type ) ->
143- # Integers are never infinity. We use inequality to make sure
144- # the operand is still a part of the computation
145- not_equal ( operand , operand , typespec )
133+ result =
134+ cond do
135+ Nx.Type . complex? ( type ) ->
136+ float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
137+ real = real ( operand , float_typespec )
138+ imag = imag ( operand , float_typespec )
139+ is_inf_real = is_infinity ( real , typespec )
140+ is_inf_imag = is_infinity ( imag , typespec )
141+ bitwise_or ( is_inf_real , is_inf_imag , typespec )
142+
143+ Nx.Type . integer? ( type ) ->
144+ # Integers are never infinity. We use inequality to make sure
145+ # the operand is still a part of the computation
146+ not_equal ( operand , operand , typespec )
147+
148+ true ->
149+ result_types = typespecs_to_mlir_types ( [ typespec ] )
150+ op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
151+ end
146152
147- true ->
148- result_types = typespecs_to_mlir_types ( [ typespec ] )
149- op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
153+ if out_typespec . type == typespec . type do
154+ result
155+ else
156+ convert ( result , out_typespec )
150157 end
151158 end
152159
153- def is_nan ( % Value { function: func } = operand , typespec ) do
160+ def is_nan ( % Value { function: func } = operand , out_typespec ) do
154161 % { type: type } = get_typespec ( operand )
155162
156- typespec = Typespec . to_type ( typespec , { :pred , 8 } )
163+ typespec = Typespec . to_type ( out_typespec , { :pred , 8 } )
157164
158- cond do
159- Nx.Type . complex? ( type ) ->
160- float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
161- real = real ( operand , float_typespec )
162- imag = imag ( operand , float_typespec )
163- is_nan_real = is_nan ( real , typespec )
164- is_nan_imag = is_nan ( imag , typespec )
165- bitwise_or ( is_nan_real , is_nan_imag , typespec )
166-
167- Nx.Type . integer? ( type ) ->
168- # Integers are never nan. We use inequality to make sure
169- # the operand is still a part of the computation
170- not_equal ( operand , operand , typespec )
165+ result =
166+ cond do
167+ Nx.Type . complex? ( type ) ->
168+ float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
169+ real = real ( operand , float_typespec )
170+ imag = imag ( operand , float_typespec )
171+ is_nan_real = is_nan ( real , typespec )
172+ is_nan_imag = is_nan ( imag , typespec )
173+ bitwise_or ( is_nan_real , is_nan_imag , typespec )
174+
175+ Nx.Type . integer? ( type ) ->
176+ # Integers are never nan. We use inequality to make sure
177+ # the operand is still a part of the computation
178+ not_equal ( operand , operand , typespec )
179+
180+ true ->
181+ result_types = typespecs_to_mlir_types ( [ typespec ] )
182+ is_inf = op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
183+ is_finite = op ( func , "stablehlo.is_finite" , [ operand ] , result_types ) |> one! ( )
184+ is_not_inf = bitwise_not ( is_inf , typespec )
185+ is_not_finite = bitwise_not ( is_finite , typespec )
186+ bitwise_and ( is_not_inf , is_not_finite , typespec )
187+ end
171188
172- true ->
173- result_types = typespecs_to_mlir_types ( [ typespec ] )
174- is_inf = op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
175- is_finite = op ( func , "stablehlo.is_finite" , [ operand ] , result_types ) |> one! ( )
176- is_not_inf = bitwise_not ( is_inf , typespec )
177- is_not_finite = bitwise_not ( is_finite , typespec )
178- bitwise_and ( is_not_inf , is_not_finite , typespec )
189+ if out_typespec . type == typespec . type do
190+ result
191+ else
192+ convert ( result , out_typespec )
179193 end
180194 end
181195
@@ -706,6 +720,10 @@ defmodule EXLA.MLIR.Value do
706720 op ( func , "stablehlo.while" , initial , result_types , regions: regions )
707721 end
708722
723+ def func_return ( func , values ) when is_list ( values ) do
724+ op ( func , "func.return" , values , [ ] )
725+ end
726+
709727 def return ( func , values ) when is_list ( values ) do
710728 op ( func , "stablehlo.return" , values , [ ] )
711729 end
0 commit comments