This repository was archived by the owner on Apr 28, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
tensor_comprehensions/pybinds Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -274,6 +274,17 @@ struct TcExecutor {
274274 return tupleOrTensor (convertToPyObjects (atOutputs));
275275 }
276276 }
277+
278+ size_t profile_kernel (const py::tuple& inputs, const py::tuple& outputs) {
279+ auto atInputs = getATenTensors (inputs);
280+ auto atOutputs = (outputs.size () > 0 )
281+ ? getATenTensors (outputs)
282+ : tc::aten::prepareOutputs (tc, entryPoint, atInputs);
283+ tc::ProfilingInfo profinfo =
284+ tc::aten::profile (*executor, atInputs, atOutputs);
285+ return profinfo.kernelRuntime .toMicroSeconds ();
286+ }
287+
277288 std::string tc;
278289 std::string entryPoint;
279290 std::unique_ptr<tc::CudaBackend::ExecutorType> executor;
@@ -485,7 +496,13 @@ PYBIND11_MODULE(tclib, m) {
485496 " unchecked_run" ,
486497 &TcExecutor::uncheckedRun,
487498 py::arg (" inputs" ),
499+ py::arg (" outputs" ) = py::tuple ())
500+ .def (
501+ " profile_kernel" ,
502+ &TcExecutor::profile_kernel,
503+ py::arg (" inputs" ),
488504 py::arg (" outputs" ) = py::tuple ());
505+
489506 m.def (
490507 " compile" ,
491508 [](const std::string& tc,
You can’t perform that action at this time.
0 commit comments