Skip to content

Commit b071c86

Browse files
authored
Merge pull request #235 from StochasticTree/python-ui-parity
Update Python data interface for parity with the R interface
2 parents 92eb863 + c956b79 commit b071c86

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

src/py_stochtree.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,20 @@ class ResidualCpp {
182182
residual_->OverwriteData(data_ptr, num_row);
183183
}
184184

185+
void AddToData(py::array_t<double> update_vector, data_size_t num_row) {
186+
// Extract pointer to contiguous block of memory
187+
double* data_ptr = static_cast<double*>(update_vector.mutable_data());
188+
// Add to data in residual_
189+
residual_->AddToData(data_ptr, num_row);
190+
}
191+
192+
void SubtractFromData(py::array_t<double> update_vector, data_size_t num_row) {
193+
// Extract pointer to contiguous block of memory
194+
double* data_ptr = static_cast<double*>(update_vector.mutable_data());
195+
// Subtract from data in residual_
196+
residual_->SubtractFromData(data_ptr, num_row);
197+
}
198+
185199
private:
186200
std::unique_ptr<StochTree::ColumnVector> residual_;
187201
};
@@ -2224,7 +2238,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
22242238
py::class_<ResidualCpp>(m, "ResidualCpp")
22252239
.def(py::init<py::array_t<double>,data_size_t>())
22262240
.def("GetResidualArray", &ResidualCpp::GetResidualArray)
2227-
.def("ReplaceData", &ResidualCpp::ReplaceData);
2241+
.def("ReplaceData", &ResidualCpp::ReplaceData)
2242+
.def("AddToData", &ResidualCpp::AddToData)
2243+
.def("SubtractFromData", &ResidualCpp::SubtractFromData);
22282244

22292245
py::class_<RngCpp>(m, "RngCpp")
22302246
.def(py::init<int>());

stochtree/data.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,37 @@ def update_data(self, new_vector: np.array) -> None:
264264
"""
265265
n = new_vector.size
266266
self.residual_cpp.ReplaceData(new_vector, n)
267+
268+
def add_vector(self, update_vector: np.array) -> None:
269+
"""
270+
Update the current state of the outcome (i.e. partial residual) data by adding each element of `update_vector`
271+
272+
Parameters
273+
----------
274+
update_vector : np.array
275+
Univariate numpy array of values to add to the current residual.
276+
"""
277+
if not isinstance(update_vector, np.ndarray):
278+
raise ValueError("update_vector must be a numpy array.")
279+
update_vector_ = np.squeeze(update_vector)
280+
if not update_vector_.ndim == 1:
281+
raise ValueError("update_vector must be a 1-dimensional numpy array.")
282+
n = update_vector_.size
283+
self.residual_cpp.AddToData(update_vector_, n)
284+
285+
def subtract_vector(self, update_vector: np.array) -> None:
286+
"""
287+
Update the current state of the outcome (i.e. partial residual) data by subtracting each element of `update_vector`
288+
289+
Parameters
290+
----------
291+
update_vector : np.array
292+
Univariate numpy array of values to subtracted from the current residual.
293+
"""
294+
if not isinstance(update_vector, np.ndarray):
295+
raise ValueError("update_vector must be a numpy array.")
296+
update_vector_ = np.squeeze(update_vector)
297+
if not update_vector_.ndim == 1:
298+
raise ValueError("update_vector must be a 1-dimensional numpy array.")
299+
n = update_vector_.size
300+
self.residual_cpp.SubtractFromData(update_vector_, n)

0 commit comments

Comments
 (0)