@@ -49,7 +49,9 @@ class TcOp : public Operator<Context> {
         OperatorBase::GetSingleArgument<std::string>("tcGradDef", "ERROR");
     gradTcName_ =
         OperatorBase::GetSingleArgument<std::string>("tcGradName", "ERROR");
-    profile_ = OperatorBase::GetSingleArgument<bool>("profile", false);
+    checkSizes_ = OperatorBase::GetSingleArgument<bool>("checkSizes", false);
+    compiled_ = false;
+    handle_ = 0;
     ArgumentHelper args(operator_def);
     if (args.HasArgument("mappingOptions")) {
       cudaMappingOptions_ = tc::CudaMappingOptions(
@@ -95,38 +97,37 @@ class TcOp : public Operator<Context> {
   }
 
   virtual bool RunOnDevice() override {
-    // first, given the TC, define it in the executionEngine_
-    executionEngine_->define(tc_);
-
-    // now, given the input tensors, convert them to dlpack tensors so that
-    // we can call the compile command
-    std::vector<::tc::dlutils::DLTensorUPtr> inTensorUPtrs;
-    std::vector<const DLTensor*> inputDLTensors;
-    for (int idx = 0; idx < this->InputSize(); ++idx) {
-      auto dims = this->Input(idx).dims();
-      inTensorUPtrs.emplace_back(
-          dlpack::makeConstDLTensor(this->Input(idx), dims));
-      inputDLTensors.push_back(inTensorUPtrs.back().get());
+    if (!compiled_) {
+      // first, given the TC, define it in the executionEngine_
+      executionEngine_->define(tc_);
+      for (int idx = 0; idx < this->InputSize(); ++idx) {
+        auto dims = this->Input(idx).dims();
+        inTensorUPtrs_.emplace_back(
+            dlpack::makeConstDLTensor(this->Input(idx), dims));
+        inputDLTensors_.push_back(inTensorUPtrs_[idx].get());
+        inputVoidPtrs_.push_back(inputDLTensors_[idx]->data);
+      }
+      auto outTensorInfo =
+          executionEngine_->inferOutputTensorInfo(tcName_, inputDLTensors_);
+      prepareOutputs(outTensorInfo);
+      for (int idx = 0; idx < OutputSize(); ++idx) {
+        outTensorUPtrs_.emplace_back(dlpack::makeDLTensor(Output(idx)));
+        outputDLTensors_.push_back(outTensorUPtrs_[idx].get());
+        outputVoidPtrs_.push_back(outputDLTensors_[idx]->data);
+      }
+      handle_ = executionEngine_->compile(
+          tcName_,
+          inputDLTensors_,
+          cudaMappingOptions_.toProtobufSerializedString());
+      compiled_ = true;
     }
 
-    auto outTensorInfo =
-        executionEngine_->inferOutputTensorInfo(tcName_, inputDLTensors);
-    prepareOutputs(outTensorInfo);
-
-    // now create the outputDLTensors
-    std::vector<::tc::dlutils::DLTensorUPtr> outTensorUPtrs;
-    std::vector<DLTensor*> outputDLTensors;
-    for (int i = 0; i < OutputSize(); ++i) {
-      outTensorUPtrs.emplace_back(dlpack::makeDLTensor(Output(i)));
-      outputDLTensors.push_back(outTensorUPtrs.back().get());
+    if (checkSizes_) {
+      executionEngine_->run(handle_, inputDLTensors_, outputDLTensors_);
+    } else {
+      executionEngine_->uncheckedRun(handle_, inputVoidPtrs_, outputVoidPtrs_);
     }
 
-    // compile and run
-    auto handle = executionEngine_->compile(
-        tcName_,
-        inputDLTensors,
-        cudaMappingOptions_.toProtobufSerializedString());
-    executionEngine_->run(handle, inputDLTensors, outputDLTensors, profile_);
     return true;
   }
@@ -135,7 +136,15 @@ class TcOp : public Operator<Context> {
   std::string gradTc_;
   std::string tcName_;
   std::string gradTcName_;
-  bool profile_;
+  bool checkSizes_;
+  bool compiled_;
+  size_t handle_;
+  std::vector<const void*> inputVoidPtrs_;
+  std::vector<void*> outputVoidPtrs_;
+  std::vector<const DLTensor*> inputDLTensors_;
+  std::vector<DLTensor*> outputDLTensors_;
+  std::vector<::tc::dlutils::DLTensorUPtr> inTensorUPtrs_;
+  std::vector<::tc::dlutils::DLTensorUPtr> outTensorUPtrs_;
   tc::CudaMappingOptions cudaMappingOptions_;
   tc::CudaMappingOptions gradCudaMappingOptions_;
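Taken together, the diff turns RunOnDevice() into a compile-once, run-many path: on the first call the op defines the TC, builds the DLTensor metadata, infers and prepares the outputs, and caches the compiled handle_; every later call reuses that state and either does a checked run() or the faster uncheckedRun() on raw data pointers, depending on the checkSizes argument. Below is a minimal self-contained sketch of that caching pattern; Engine, CachedOp and "my_tc" are hypothetical stand-ins used for illustration, not the real TC execution engine API.

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for the TC execution engine; only the caching
// structure mirrors the diff above, not the real TC interface.
struct Engine {
  size_t compile(const std::string& name, const std::vector<int>& sizes) {
    std::printf("compiling %s once for %zu input sizes\n", name.c_str(), sizes.size());
    return 42;  // pretend handle
  }
  void run(size_t handle, const std::vector<int>& sizes) {
    std::printf("checked run of handle %zu (%zu sizes validated)\n", handle, sizes.size());
  }
  void uncheckedRun(size_t handle) {
    std::printf("unchecked run of handle %zu\n", handle);
  }
};

class CachedOp {
 public:
  explicit CachedOp(bool checkSizes) : checkSizes_(checkSizes) {}

  // Mirrors the new RunOnDevice(): compile lazily on first use, then reuse.
  void run(const std::vector<int>& sizes) {
    if (!compiled_) {
      handle_ = engine_.compile("my_tc", sizes);
      compiled_ = true;
    }
    if (checkSizes_) {
      engine_.run(handle_, sizes);    // validates tensor metadata each call
    } else {
      engine_.uncheckedRun(handle_);  // fast path, raw pointers only
    }
  }

 private:
  Engine engine_;
  bool checkSizes_;
  bool compiled_ = false;
  size_t handle_ = 0;
};

int main() {
  CachedOp op(/*checkSizes=*/false);
  op.run({128, 256});  // compiles on the first call
  op.run({128, 256});  // reuses the cached handle
}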