Skip to content

Commit 2f8ccce

Browse files
committed
Add adaptive epsilon and configurable precision parameters
- Fix query methods to use Real type instead of hardcoded float - find_one() now accepts vec<Real> for proper float64 precision - find_all() now accepts py::array_t<Real> matching tree precision - Fix Python wrapper to preserve precision settings on first insert - Handle subnormal detection disabled case with workaround - Preserve relative_epsilon, absolute_epsilon, adaptive_epsilon settings - Remove obsolete query_exact and refine_candidates code - All precision tests now passing
1 parent 15ade9f commit 2f8ccce

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

include/prtree/core/prtree.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ template <IndexType T, int B = 6, int D = 2, typename Real = float> class PRTree
584584
#endif
585585
}
586586

587-
auto find_all(const py::array_t<float> &x) {
587+
auto find_all(const py::array_t<Real> &x) {
588588
#ifdef MY_DEBUG
589589
ProfilerStart("find_all.prof");
590590
std::cout << "profiler start of find_all" << std::endl;
@@ -728,40 +728,34 @@ template <IndexType T, int B = 6, int D = 2, typename Real = float> class PRTree
728728
return out;
729729
}
730730

731-
auto find_all_array(const py::array_t<float> &x) {
731+
auto find_all_array(const py::array_t<Real> &x) {
732732
return list_list_to_arrays(std::move(find_all(x)));
733733
}
734734

735-
auto find_one(const vec<float> &x) {
735+
auto find_one(const vec<Real> &x) {
736736
bool is_point = false;
737737
if (unlikely(!(x.size() == 2 * D || x.size() == D))) {
738738
throw std::runtime_error("invalid shape");
739739
}
740740
Real minima[D];
741741
Real maxima[D];
742-
std::array<double, 2 * D> query_exact;
743742

744743
if (x.size() == D) {
745744
is_point = true;
746745
}
747746
for (int i = 0; i < D; ++i) {
748747
minima[i] = x.at(i);
749-
query_exact[i] = static_cast<double>(x.at(i));
750748

751749
if (is_point) {
752750
maxima[i] = minima[i];
753-
query_exact[i + D] = query_exact[i];
754751
} else {
755752
maxima[i] = x.at(i + D);
756-
query_exact[i + D] = static_cast<double>(x.at(i + D));
757753
}
758754
}
759755
const auto bb = BB<D, Real>(minima, maxima);
760756
auto candidates = find(bb);
761757

762-
// Refine with double precision if exact coordinates are available
763-
auto out = candidates;
764-
return out;
758+
return candidates;
765759
}
766760

767761
// Helper method: Check intersection with double precision (closed interval

src/python_prtree/core.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,48 @@ def insert(
201201

202202
objdumps = _dumps(obj)
203203
if self.n == 0:
204-
# Reinitialize tree with correct precision
204+
# Reinitialize tree with correct precision and preserve settings
205205
Klass = self.Klass_float64 if self._use_float64 else self.Klass_float32
206-
self._tree = Klass([idx], [bb])
207-
self._tree.set_obj(idx, objdumps)
206+
old_tree = self._tree
207+
208+
# Check if subnormal detection is disabled - if so, use workaround
209+
subnormal_disabled = (hasattr(old_tree, 'get_subnormal_detection') and
210+
not old_tree.get_subnormal_detection())
211+
212+
if subnormal_disabled:
213+
# Create with dummy valid box first
214+
dummy_idx = -999999
215+
dummy_bb = np.ones(len(bb), dtype=bb.dtype)
216+
self._tree = Klass([dummy_idx], [dummy_bb])
217+
218+
# Preserve settings and disable subnormal detection
219+
if hasattr(old_tree, 'get_relative_epsilon'):
220+
self._tree.set_relative_epsilon(old_tree.get_relative_epsilon())
221+
if hasattr(old_tree, 'get_absolute_epsilon'):
222+
self._tree.set_absolute_epsilon(old_tree.get_absolute_epsilon())
223+
if hasattr(old_tree, 'get_adaptive_epsilon'):
224+
self._tree.set_adaptive_epsilon(old_tree.get_adaptive_epsilon())
225+
self._tree.set_subnormal_detection(False)
226+
227+
# Now insert the real box (tree is not empty, insert will work)
228+
self._tree.insert(idx, bb, objdumps)
229+
# Erase dummy
230+
self._tree.erase(dummy_idx)
231+
else:
232+
# Normal path
233+
self._tree = Klass([idx], [bb])
234+
235+
# Preserve settings from old tree
236+
if hasattr(old_tree, 'get_relative_epsilon'):
237+
self._tree.set_relative_epsilon(old_tree.get_relative_epsilon())
238+
if hasattr(old_tree, 'get_absolute_epsilon'):
239+
self._tree.set_absolute_epsilon(old_tree.get_absolute_epsilon())
240+
if hasattr(old_tree, 'get_adaptive_epsilon'):
241+
self._tree.set_adaptive_epsilon(old_tree.get_adaptive_epsilon())
242+
if hasattr(old_tree, 'get_subnormal_detection'):
243+
self._tree.set_subnormal_detection(old_tree.get_subnormal_detection())
244+
245+
self._tree.set_obj(idx, objdumps)
208246
else:
209247
self._tree.insert(idx, bb, objdumps)
210248

0 commit comments

Comments
 (0)