@@ -64,15 +64,20 @@ defmodule EXLA.MLIR.Value do
6464 % { type: rhs_type } = get_typespec ( rhs )
6565
6666 comparison_type =
67- if Nx.Type . float? ( lhs_type ) or Nx.Type . float? ( rhs_type ) do
68- attr_comparison_type ( :totalorder )
69- else
70- attr_comparison_type ( :notype )
67+ cond do
68+ Nx.Type . complex? ( lhs_type ) or Nx.Type . complex? ( rhs_type ) ->
69+ attr_comparison_type ( :float )
70+
71+ Nx.Type . float? ( lhs_type ) or Nx.Type . float? ( rhs_type ) ->
72+ attr_comparison_type ( :float )
73+
74+ true ->
75+ attr_comparison_type ( :notype )
7176 end
7277
7378 attributes = [
7479 comparison_direction: attr_comparison_direction ( direction ) ,
75- comparison_type : comparison_type
80+ compare_type : comparison_type
7681 ]
7782
7883 result_types = typespecs_to_mlir_types ( [ Typespec . to_type ( typespec , { :pred , 8 } ) ] )
@@ -929,7 +934,7 @@ defmodule EXLA.MLIR.Value do
929934 defp attr_comparison_direction ( value ) when value in [ :eq , :lt , :le , :gt , :ge , :ne ] ,
930935 do: attr_enum ( "stablehlo" , "comparison_direction" , value )
931936
932- defp attr_comparison_type ( value ) when value in [ :totalorder , :notype ] ,
937+ defp attr_comparison_type ( value ) when value in [ :float , : totalorder, :notype ] ,
933938 do: attr_enum ( "stablehlo" , "comparison_type" , value )
934939
935940 defp attr_precision ( value ) when value in [ :default , :high , :highest ] ,
0 commit comments