Skip to content

Commit fd65c44

Browse files
Fix: fix AtomicData.py mask handling and ensure proper numpy arrays (#286)
* fix: Convert mask to numpy and Ensure arrays are proper numpy * fix: Ensure o_shift is a 2D numpy array for isolated systems * fix: Use .cpu().numpy() for mask conversion to ensure compatibility with numpy arrays
1 parent bd6677a commit fd65c44

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

dptb/data/AtomicData.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,16 +1009,22 @@ def neighbor_list_and_relative_vec(
10091009
shifts = shifts[mask]
10101010

10111011
# 2. for i == j
1012-
1012+
10131013
mask = torch.ones(len(first_idex), dtype=torch.bool)
10141014
mask[first_idex == second_idex] = False
10151015
# get index bool type ~mask for i == j.
1016-
o_first_idex = first_idex[~mask]
1017-
o_second_idex = second_idex[~mask]
1018-
o_shift = shifts[~mask]
1016+
# Convert mask to numpy for consistent indexing behavior
1017+
mask_np = mask.cpu().numpy()
1018+
o_first_idex = first_idex[~mask_np]
1019+
o_second_idex = second_idex[~mask_np]
1020+
o_shift = shifts[~mask_np]
10191021
o_mask = mask[~mask] # this is all False, with length being the number all the bonds with i == j.
10201022

1021-
1023+
# Ensure arrays are proper numpy arrays (not scalars) for isolated systems
1024+
o_first_idex = np.atleast_1d(o_first_idex)
1025+
o_second_idex = np.atleast_1d(o_second_idex)
1026+
o_shift = np.atleast_2d(o_shift)
1027+
10221028
# using the dict key to remove the duplicate bonds, because it is O(1) to check if a key is in the dict.
10231029
rev_dict = {}
10241030
for i in range(len(o_first_idex)):
@@ -1042,10 +1048,12 @@ def neighbor_list_and_relative_vec(
10421048
del o_shift
10431049
mask[~mask] = o_mask
10441050
del o_mask
1045-
1046-
first_idex = torch.LongTensor(first_idex[mask], device=out_device)
1047-
second_idex = torch.LongTensor(second_idex[mask], device=out_device)
1048-
shifts = torch.as_tensor(shifts[mask], dtype=out_dtype, device=out_device)
1051+
1052+
# Convert mask to numpy for indexing numpy arrays (avoids torch/numpy compatibility issues)
1053+
mask_np = mask.cpu().numpy()
1054+
first_idex = torch.as_tensor(first_idex[mask_np], dtype=torch.long, device=out_device)
1055+
second_idex = torch.as_tensor(second_idex[mask_np], dtype=torch.long, device=out_device)
1056+
shifts = torch.as_tensor(shifts[mask_np], dtype=out_dtype, device=out_device)
10491057

10501058
if not reduce:
10511059
assert self_interaction == False, "for self_interaction = True, i i 0 0 0 will be duplicated."

0 commit comments

Comments
 (0)