@@ -182,6 +182,20 @@ class ResidualCpp {
182182 residual_->OverwriteData (data_ptr, num_row);
183183 }
184184
185+ void AddToData (py::array_t <double > update_vector, data_size_t num_row) {
186+ // Extract pointer to contiguous block of memory
187+ double * data_ptr = static_cast <double *>(update_vector.mutable_data ());
188+ // Add to data in residual_
189+ residual_->AddToData (data_ptr, num_row);
190+ }
191+
192+ void SubtractFromData (py::array_t <double > update_vector, data_size_t num_row) {
193+ // Extract pointer to contiguous block of memory
194+ double * data_ptr = static_cast <double *>(update_vector.mutable_data ());
195+ // Subtract from data in residual_
196+ residual_->SubtractFromData (data_ptr, num_row);
197+ }
198+
185199 private:
186200 std::unique_ptr<StochTree::ColumnVector> residual_;
187201};
@@ -2224,7 +2238,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
22242238 py::class_<ResidualCpp>(m, " ResidualCpp" )
22252239 .def (py::init<py::array_t <double >,data_size_t >())
22262240 .def (" GetResidualArray" , &ResidualCpp::GetResidualArray)
2227- .def (" ReplaceData" , &ResidualCpp::ReplaceData);
2241+ .def (" ReplaceData" , &ResidualCpp::ReplaceData)
2242+ .def (" AddToData" , &ResidualCpp::AddToData)
2243+ .def (" SubtractFromData" , &ResidualCpp::SubtractFromData);
22282244
22292245 py::class_<RngCpp>(m, " RngCpp" )
22302246 .def (py::init<int >());
0 commit comments