@@ -1499,64 +1499,23 @@ defmodule EXLA.Defn.ExprTest do
14991499 end
15001500 end
15011501
1502- describe "map" do
1503- defn map_plus ( t ) , do: Nx . map ( t , fn x -> x + 1 end )
1504- defn map_equal ( t ) , do: Nx . map ( t , [ type: { :f , 64 } ] , fn x -> Nx . equal ( x , 1 ) end )
1505- defn map_exp ( t ) , do: Nx . map ( t , [ type: { :f , 64 } ] , fn x -> Nx . exp ( x ) end )
1506-
1507- @ tag :unsupported_64_bit_op
1508- test "maps a function over the tensor" do
1509- assert_equal ( map_plus ( Nx . tensor ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) ) , Nx . tensor ( [ [ 2 , 3 , 4 ] , [ 5 , 6 , 7 ] ] ) )
1510- end
1511-
1512- @ tag :unsupported_64_bit_op
1513- test "maps a function with an output type" do
1514- assert_equal (
1515- map_equal ( Nx . tensor ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) ) ,
1516- Nx . tensor ( [ [ 1.0 , 0.0 , 0.0 ] , [ 0.0 , 0.0 , 0.0 ] ] , type: { :f , 64 } )
1517- )
1518-
1519- assert_equal (
1520- map_exp ( Nx . tensor ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) ) ,
1521- Nx . tensor (
1522- [
1523- [ 2.718281828459045 , 7.38905609893065 , 20.085536923187668 ] ,
1524- [ 54.598150033144236 , 148.4131591025766 , 403.4287934927351 ]
1525- ] ,
1526- type: { :f , 64 }
1527- )
1528- )
1529- end
1530-
1531- defn map_conditional ( t ) , do: Nx . map ( t , fn x -> if x > 0 , do: x , else: - x end )
1532-
1533- @ tag :conditional_inside_map_reduce
1534- @ tag :unsupported_64_bit_op
1535- test "maps a function with conditional" do
1536- assert_equal (
1537- map_conditional ( Nx . tensor ( [ - 2 , - 1 , 0 , 1 , 2 ] ) ) ,
1538- Nx . tensor ( [ 2 , 1 , 0 , 1 , 2 ] )
1539- )
1540- end
1541-
1542- defn while_inside_if ( pred , x ) do
1543- if pred do
1544- { x , _ } =
1545- while { x , i = 0 } , i < 10 do
1546- { x , i + 1 }
1547- end
1502+ defn while_inside_if ( pred , x ) do
1503+ if pred do
1504+ { x , _ } =
1505+ while { x , i = 0 } , i < 10 do
1506+ { x , i + 1 }
1507+ end
15481508
1549- x
1550- else
1551- x
1552- end
1509+ x
1510+ else
1511+ x
15531512 end
1513+ end
15541514
1555- test "while inside if" do
1556- assert % { a: a , b: b } = while_inside_if ( 1 , % { a: 1 , b: 2.0 } )
1557- assert_all_close ( a , 1 )
1558- assert_all_close ( b , 2.0 )
1559- end
1515+ test "while inside if" do
1516+ assert % { a: a , b: b } = while_inside_if ( 1 , % { a: 1 , b: 2.0 } )
1517+ assert_all_close ( a , 1 )
1518+ assert_all_close ( b , 2.0 )
15601519 end
15611520
15621521 describe "reduce" do
0 commit comments