@@ -1009,16 +1009,22 @@ def neighbor_list_and_relative_vec(
10091009 shifts = shifts [mask ]
10101010
10111011 # 2. for i == j
1012-
1012+
10131013 mask = torch .ones (len (first_idex ), dtype = torch .bool )
10141014 mask [first_idex == second_idex ] = False
10151015 # get index bool type ~mask for i == j.
1016- o_first_idex = first_idex [~ mask ]
1017- o_second_idex = second_idex [~ mask ]
1018- o_shift = shifts [~ mask ]
1016+ # Convert mask to numpy for consistent indexing behavior
1017+ mask_np = mask .cpu ().numpy ()
1018+ o_first_idex = first_idex [~ mask_np ]
1019+ o_second_idex = second_idex [~ mask_np ]
1020+ o_shift = shifts [~ mask_np ]
10191021 o_mask = mask [~ mask ] # this is all False, with length being the number all the bonds with i == j.
10201022
1021-
1023+ # Ensure arrays are proper numpy arrays (not scalars) for isolated systems
1024+ o_first_idex = np .atleast_1d (o_first_idex )
1025+ o_second_idex = np .atleast_1d (o_second_idex )
1026+ o_shift = np .atleast_2d (o_shift )
1027+
10221028 # using the dict key to remove the duplicate bonds, because it is O(1) to check if a key is in the dict.
10231029 rev_dict = {}
10241030 for i in range (len (o_first_idex )):
@@ -1042,10 +1048,12 @@ def neighbor_list_and_relative_vec(
10421048 del o_shift
10431049 mask [~ mask ] = o_mask
10441050 del o_mask
1045-
1046- first_idex = torch .LongTensor (first_idex [mask ], device = out_device )
1047- second_idex = torch .LongTensor (second_idex [mask ], device = out_device )
1048- shifts = torch .as_tensor (shifts [mask ], dtype = out_dtype , device = out_device )
1051+
1052+ # Convert mask to numpy for indexing numpy arrays (avoids torch/numpy compatibility issues)
1053+ mask_np = mask .cpu ().numpy ()
1054+ first_idex = torch .as_tensor (first_idex [mask_np ], dtype = torch .long , device = out_device )
1055+ second_idex = torch .as_tensor (second_idex [mask_np ], dtype = torch .long , device = out_device )
1056+ shifts = torch .as_tensor (shifts [mask_np ], dtype = out_dtype , device = out_device )
10491057
10501058 if not reduce :
10511059 assert self_interaction == False , "for self_interaction = True, i i 0 0 0 will be duplicated."
0 commit comments