@@ -112,22 +112,45 @@ for (jlop, hloop, hlocomp) in (
112112 (:(Base.:(<= )), :compare , " LE" ),
113113 (:(Base.:(< )), :compare , " LT" ),
114114)
115- @eval function $ (jlop)(
116- @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
117- ) where {T}
118- return TracedRNumber {Bool} (
119- (),
120- MLIR. IR. result (
121- MLIR. Dialects. stablehlo.$ (hloop)(
122- lhs. mlir_data,
123- rhs. mlir_data;
124- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
125- MLIR. IR. context (), $ hlocomp
115+ @eval begin
116+ function $ (jlop)(
117+ @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
118+ ) where {T}
119+ return TracedRNumber {Bool} (
120+ (),
121+ MLIR. IR. result (
122+ MLIR. Dialects. stablehlo.$ (hloop)(
123+ lhs. mlir_data,
124+ rhs. mlir_data;
125+ comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
126+ MLIR. IR. context (), $ hlocomp
127+ ),
126128 ),
129+ 1 ,
127130 ),
128- 1 ,
129- ),
130- )
131+ )
132+ end
133+
134+ function $ (jlop)(
135+ @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs)
136+ ) where {T}
137+ return $ (jlop)(lhs, promote_to (lhs, rhs))
138+ end
139+
140+ function $ (jlop)(
141+ @nospecialize (lhs), @nospecialize (rhs:: TracedRNumber{T} )
142+ ) where {T}
143+ return $ (jlop)(promote_to (rhs, lhs), rhs)
144+ end
145+
146+ function $ (jlop)(
147+ @nospecialize (lhs:: TracedRNumber{T1} ), @nospecialize (rhs:: TracedRNumber{T2} )
148+ ) where {T1,T2}
149+ commonTy = TracedRNumber{Base. promote_type (T1, T2)}
150+ lhs = promote_to (commonTy, lhs)
151+ rhs = promote_to (commonTy, rhs)
152+ return $ (jlop)(lhs, rhs)
153+ end
131154 end
132155end
133156
@@ -169,6 +192,9 @@ for (jlop, hloop) in (
169192 end
170193end
171194
195+ # XXX : Enzyme-MLIR doesn't have `abs` adjoint defined
196+ Base. abs2 (x:: TracedRNumber{<:Real} ) = x^ 2
197+
172198struct TypeCast{T<: ReactantPrimitives } <: Function end
173199
174200(:: TypeCast{T} )(x:: TracedRNumber{T2} ) where {T,T2} = promote_to (TracedRNumber{T}, x)
0 commit comments