Skip to content

Commit 0c9c90b

Browse files
committed
Fix precision validation gaps and enhance insert() capabilities
This commit addresses remaining issues with precision handling: 1. **Auto-detect precision when loading from file**: When loading a tree from a saved file via `PRTree3D(filepath)`, the wrapper now automatically tries both float32 and float64 to determine which precision was used when the tree was saved. This fixes the `test_save_load_float32_no_regression` test failure where loading a float32 tree defaulted to float64 and caused std::bad_alloc. 2. **Improved error handling**: If both precision attempts fail when loading from file, provides an informative error message about potential file corruption. The fix ensures that precision is correctly preserved across save/load cycles without requiring users to manually specify the precision when loading. Fixes: - test_save_load_float32_no_regression now passes - test_save_load_float64_matteo_case continues to pass
1 parent 2f8ccce commit 0c9c90b

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

src/python_prtree/core.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
5252
- float32 input → float32 tree (native float32 precision)
5353
- float64 input → float64 tree (native double precision)
5454
- No input → float64 tree (default to higher precision)
55+
- filepath input → auto-detect precision from saved file
5556
"""
5657
if self.Klass_float32 is None or self.Klass_float64 is None:
5758
raise NotImplementedError("Use PRTree2D, PRTree3D, or PRTree4D")
@@ -79,10 +80,32 @@ def __init__(self, *args, **kwargs):
7980
args[1] = np.asarray(boxes, dtype=np.float64)
8081
use_float64 = True
8182

82-
# Select appropriate class
83-
Klass = self.Klass_float64 if use_float64 else self.Klass_float32
84-
self._tree = Klass(*args, **kwargs)
85-
self._use_float64 = use_float64
83+
# Select appropriate class
84+
Klass = self.Klass_float64 if use_float64 else self.Klass_float32
85+
self._tree = Klass(*args, **kwargs)
86+
self._use_float64 = use_float64
87+
elif len(args) == 1 and isinstance(args[0], str):
88+
# Loading from file - try both precisions to auto-detect
89+
filepath = args[0]
90+
91+
# Try float32 first (more common for saved files)
92+
try:
93+
self._tree = self.Klass_float32(filepath, **kwargs)
94+
self._use_float64 = False
95+
except Exception:
96+
# If float32 fails, try float64
97+
try:
98+
self._tree = self.Klass_float64(filepath, **kwargs)
99+
self._use_float64 = True
100+
except Exception as e:
101+
# Both failed - raise informative error
102+
raise ValueError(f"Failed to load tree from {filepath}. "
103+
f"File may be corrupted or in unsupported format.") from e
104+
else:
105+
# Empty constructor or other cases - default to float64
106+
Klass = self.Klass_float64 if use_float64 else self.Klass_float32
107+
self._tree = Klass(*args, **kwargs)
108+
self._use_float64 = use_float64
86109

87110
def __getattr__(self, name):
88111
"""Delegate attribute access to underlying C++ tree."""

0 commit comments

Comments
 (0)