66from onnx_array_api .ext_test_case import ExtTestCase
77from onnx_array_api .ort .ort_optimizers import ort_optimized_model
88from onnx_array_api .ort .ort_profile import ort_profile , merge_ort_profile
9+ from onnxruntime .capi ._pybind_state import (
10+ OrtValue as C_OrtValue ,
11+ OrtDevice as C_OrtDevice ,
12+ )
913
1014
1115class TestOrtProfile (ExtTestCase ):
@@ -28,7 +32,76 @@ def myloss(x, y):
2832 self .assertRaise (lambda : ort_optimized_model (onx , "NO" ), ValueError )
2933 optimized = ort_optimized_model (onx )
3034 prof = ort_profile (optimized , feeds )
31- prof .to_csv ("prof.csv" , index = False )
35+ self .assertIsInstance (prof , DataFrame )
36+ prof = ort_profile (optimized , feeds , as_df = False )
37+ self .assertIsInstance (prof , list )
38+
39+ def test_ort_profile_first_it_out (self ):
40+ def l1_loss (x , y ):
41+ return absolute (x - y ).sum ()
42+
43+ def l2_loss (x , y ):
44+ return ((x - y ) ** 2 ).sum ()
45+
46+ def myloss (x , y ):
47+ return l1_loss (x [:, 0 ], y [:, 0 ]) + l2_loss (x [:, 1 ], y [:, 1 ])
48+
49+ jitted_myloss = jit_onnx (myloss )
50+ x = np .array ([[0.1 , 0.2 ], [0.3 , 0.4 ]], dtype = np .float32 )
51+ y = np .array ([[0.11 , 0.22 ], [0.33 , 0.44 ]], dtype = np .float32 )
52+ jitted_myloss (x , y )
53+ onx = jitted_myloss .get_onnx ()
54+ feeds = {"x0" : x , "x1" : y }
55+ self .assertRaise (lambda : ort_optimized_model (onx , "NO" ), ValueError )
56+ optimized = ort_optimized_model (onx )
57+ prof = ort_profile (optimized , feeds )
58+ events = {
59+ "kernel_time" ,
60+ "fence_before" ,
61+ "fence_after" ,
62+ "SequentialExecutor::Execute" ,
63+ "model_run" ,
64+ "model_loading_array" ,
65+ "session_initialization" ,
66+ }
67+ self .assertEqual (set (prof ["event_name" ]), events )
68+ agg = ort_profile (optimized , feeds , first_it_out = True , agg = True )
69+ self .assertIsInstance (agg , DataFrame )
70+ self .assertLess (agg .shape [0 ], prof .shape [0 ])
71+ self .assertEqual (set (agg .reset_index (drop = False )["event_name" ]), events )
72+ agg = ort_profile (
73+ optimized , feeds , first_it_out = True , agg = True , agg_op_name = False
74+ )
75+ self .assertIsInstance (agg , DataFrame )
76+ self .assertLess (agg .shape [0 ], prof .shape [0 ])
77+ self .assertEqual (set (agg .reset_index (drop = False )["event_name" ]), events )
78+
79+ def test_ort_profile_ort_value (self ):
80+ def to_ort_value (m ):
81+ device = C_OrtDevice (C_OrtDevice .cpu (), C_OrtDevice .default_memory (), 0 )
82+ ort_value = C_OrtValue .ortvalue_from_numpy (m , device )
83+ return ort_value
84+
85+ def l1_loss (x , y ):
86+ return absolute (x - y ).sum ()
87+
88+ def l2_loss (x , y ):
89+ return ((x - y ) ** 2 ).sum ()
90+
91+ def myloss (x , y ):
92+ return l1_loss (x [:, 0 ], y [:, 0 ]) + l2_loss (x [:, 1 ], y [:, 1 ])
93+
94+ jitted_myloss = jit_onnx (myloss )
95+ x = np .array ([[0.1 , 0.2 ], [0.3 , 0.4 ]], dtype = np .float32 )
96+ y = np .array ([[0.11 , 0.22 ], [0.33 , 0.44 ]], dtype = np .float32 )
97+ jitted_myloss (x , y )
98+ onx = jitted_myloss .get_onnx ()
99+ np_feeds = {"x0" : x , "x1" : y }
100+ feeds = {k : to_ort_value (v ) for k , v in np_feeds .items ()}
101+
102+ self .assertRaise (lambda : ort_optimized_model (onx , "NO" ), ValueError )
103+ optimized = ort_optimized_model (onx )
104+ prof = ort_profile (optimized , feeds )
32105 self .assertIsInstance (prof , DataFrame )
33106 prof = ort_profile (optimized , feeds , as_df = False )
34107 self .assertIsInstance (prof , list )
0 commit comments