This repository was archived by the owner on Apr 28, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
tensor_comprehensions/pybinds Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -273,6 +273,17 @@ struct TcExecutor {
273273 return tupleOrTensor (convertToPyObjects (atOutputs));
274274 }
275275 }
276+
277+ size_t profile_kernel (const py::tuple& inputs, const py::tuple& outputs) {
278+ auto atInputs = getATenTensors (inputs);
279+ auto atOutputs = (outputs.size () > 0 )
280+ ? getATenTensors (outputs)
281+ : tc::aten::prepareOutputs (tc, entryPoint, atInputs);
282+ tc::ProfilingInfo profinfo =
283+ tc::aten::profile (*executor, atInputs, atOutputs);
284+ return profinfo.kernelRuntime .toMicroSeconds ();
285+ }
286+
276287 std::string tc;
277288 std::string entryPoint;
278289 std::unique_ptr<tc::CudaBackend::ExecutorType> executor;
@@ -465,7 +476,13 @@ PYBIND11_MODULE(tclib, m) {
465476 " unchecked_run" ,
466477 &TcExecutor::uncheckedRun,
467478 py::arg (" inputs" ),
479+ py::arg (" outputs" ) = py::tuple ())
480+ .def (
481+ " profile_kernel" ,
482+ &TcExecutor::profile_kernel,
483+ py::arg (" inputs" ),
468484 py::arg (" outputs" ) = py::tuple ());
485+
469486 m.def (
470487 " compile" ,
471488 [](const std::string& tc,
You can’t perform that action at this time.
0 commit comments