Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,16 +1009,22 @@ def neighbor_list_and_relative_vec(
shifts = shifts[mask]

# 2. for i == j

mask = torch.ones(len(first_idex), dtype=torch.bool)
mask[first_idex == second_idex] = False
# get index bool type ~mask for i == j.
o_first_idex = first_idex[~mask]
o_second_idex = second_idex[~mask]
o_shift = shifts[~mask]
# Convert mask to numpy for consistent indexing behavior
mask_np = mask.cpu().numpy()
o_first_idex = first_idex[~mask_np]
o_second_idex = second_idex[~mask_np]
o_shift = shifts[~mask_np]
o_mask = mask[~mask] # this is all False, with length being the number all the bonds with i == j.


# Ensure arrays are proper numpy arrays (not scalars) for isolated systems
o_first_idex = np.atleast_1d(o_first_idex)
o_second_idex = np.atleast_1d(o_second_idex)
o_shift = np.atleast_2d(o_shift)

# using the dict key to remove the duplicate bonds, because it is O(1) to check if a key is in the dict.
rev_dict = {}
for i in range(len(o_first_idex)):
Expand All @@ -1042,10 +1048,12 @@ def neighbor_list_and_relative_vec(
del o_shift
mask[~mask] = o_mask
del o_mask

first_idex = torch.LongTensor(first_idex[mask], device=out_device)
second_idex = torch.LongTensor(second_idex[mask], device=out_device)
shifts = torch.as_tensor(shifts[mask], dtype=out_dtype, device=out_device)

# Convert mask to numpy for indexing numpy arrays (avoids torch/numpy compatibility issues)
mask_np = mask.cpu().numpy()
first_idex = torch.as_tensor(first_idex[mask_np], dtype=torch.long, device=out_device)
second_idex = torch.as_tensor(second_idex[mask_np], dtype=torch.long, device=out_device)
shifts = torch.as_tensor(shifts[mask_np], dtype=out_dtype, device=out_device)

if not reduce:
assert self_interaction == False, "for self_interaction = True, i i 0 0 0 will be duplicated."
Expand Down
Loading