@@ -903,6 +903,24 @@ DPCTLDevice_GetCompositeDevice(__dpctl_keep const DPCTLSyclDeviceRef DRef)
903903 return nullptr ;
904904}
905905
906+ bool _CallPeerAccess (device dev, device peer)
907+ {
908+ auto BE1 = dev.get_backend ();
909+ auto BE2 = peer.get_backend ();
910+
911+ if ((BE1 != sycl::backend::ext_oneapi_level_zero &&
912+ BE1 != sycl::backend::ext_oneapi_cuda &&
913+ BE1 != sycl::backend::ext_oneapi_hip) ||
914+ (BE2 != sycl::backend::ext_oneapi_level_zero &&
915+ BE2 != sycl::backend::ext_oneapi_cuda &&
916+ BE2 != sycl::backend::ext_oneapi_hip) ||
917+ (dev == peer))
918+ {
919+ return false ;
920+ }
921+ return true ;
922+ }
923+
906924bool DPCTLDevice_CanAccessPeer (__dpctl_keep const DPCTLSyclDeviceRef DRef,
907925 __dpctl_keep const DPCTLSyclDeviceRef PDRef,
908926 DPCTLPeerAccessType PT)
@@ -911,33 +929,13 @@ bool DPCTLDevice_CanAccessPeer(__dpctl_keep const DPCTLSyclDeviceRef DRef,
911929 auto D = unwrap<device>(DRef);
912930 auto PD = unwrap<device>(PDRef);
913931 if (D && PD) {
914- auto BE1 = D->get_backend ();
915- auto BE2 = PD->get_backend ();
916-
917- if (BE1 != sycl::backend::ext_oneapi_level_zero &&
918- BE1 != sycl::backend::ext_oneapi_cuda &&
919- BE1 != sycl::backend::ext_oneapi_hip)
920- {
921- std::ostringstream os;
922- os << " Backend " << BE1 << " does not support peer access" ;
923- error_handler (os.str (), __FILE__, __func__, __LINE__);
924- return false ;
925- }
926-
927- if (BE2 != sycl::backend::ext_oneapi_level_zero &&
928- BE2 != sycl::backend::ext_oneapi_cuda &&
929- BE2 != sycl::backend::ext_oneapi_hip)
930- {
931- std::ostringstream os;
932- os << " Backend " << BE2 << " does not support peer access" ;
933- error_handler (os.str (), __FILE__, __func__, __LINE__);
934- return false ;
935- }
936- try {
937- canAccess = D->ext_oneapi_can_access_peer (
938- *PD, DPCTL_DPCTLPeerAccessTypeToSycl (PT));
939- } catch (std::exception const &e) {
940- error_handler (e, __FILE__, __func__, __LINE__);
932+ if (_CallPeerAccess (*D, *PD)) {
933+ try {
934+ canAccess = D->ext_oneapi_can_access_peer (
935+ *PD, DPCTL_DPCTLPeerAccessTypeToSycl (PT));
936+ } catch (std::exception const &e) {
937+ error_handler (e, __FILE__, __func__, __LINE__);
938+ }
941939 }
942940 }
943941 return canAccess;
@@ -949,31 +947,18 @@ void DPCTLDevice_EnablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
949947 auto D = unwrap<device>(DRef);
950948 auto PD = unwrap<device>(PDRef);
951949 if (D && PD) {
952- auto BE1 = D->get_backend ();
953- auto BE2 = PD->get_backend ();
954-
955- if (BE1 != sycl::backend::ext_oneapi_level_zero &&
956- BE1 != sycl::backend::ext_oneapi_cuda &&
957- BE1 != sycl::backend::ext_oneapi_hip)
958- {
959- std::ostringstream os;
960- os << " Backend " << BE1 << " does not support peer access" ;
961- error_handler (os.str (), __FILE__, __func__, __LINE__);
950+ if (_CallPeerAccess (*D, *PD)) {
951+ try {
952+ D->ext_oneapi_enable_peer_access (*PD);
953+ } catch (std::exception const &e) {
954+ error_handler (e, __FILE__, __func__, __LINE__);
955+ }
962956 }
963-
964- if (BE2 != sycl::backend::ext_oneapi_level_zero &&
965- BE2 != sycl::backend::ext_oneapi_cuda &&
966- BE2 != sycl::backend::ext_oneapi_hip)
967- {
957+ else {
968958 std::ostringstream os;
969- os << " Backend " << BE2 << " does not support peer access" ;
959+ os << " Given devices do not support peer access" ;
970960 error_handler (os.str (), __FILE__, __func__, __LINE__);
971961 }
972- try {
973- D->ext_oneapi_enable_peer_access (*PD);
974- } catch (std::exception const &e) {
975- error_handler (e, __FILE__, __func__, __LINE__);
976- }
977962 }
978963 return ;
979964}
@@ -984,31 +969,18 @@ void DPCTLDevice_DisablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
984969 auto D = unwrap<device>(DRef);
985970 auto PD = unwrap<device>(PDRef);
986971 if (D && PD) {
987- auto BE1 = D->get_backend ();
988- auto BE2 = PD->get_backend ();
989-
990- if (BE1 != sycl::backend::ext_oneapi_level_zero &&
991- BE1 != sycl::backend::ext_oneapi_cuda &&
992- BE1 != sycl::backend::ext_oneapi_hip)
993- {
994- std::ostringstream os;
995- os << " Backend " << BE1 << " does not support peer access" ;
996- error_handler (os.str (), __FILE__, __func__, __LINE__);
972+ if (_CallPeerAccess (*D, *PD)) {
973+ try {
974+ D->ext_oneapi_disable_peer_access (*PD);
975+ } catch (std::exception const &e) {
976+ error_handler (e, __FILE__, __func__, __LINE__);
977+ }
997978 }
998-
999- if (BE2 != sycl::backend::ext_oneapi_level_zero &&
1000- BE2 != sycl::backend::ext_oneapi_cuda &&
1001- BE2 != sycl::backend::ext_oneapi_hip)
1002- {
979+ else {
1003980 std::ostringstream os;
1004- os << " Backend " << BE2 << " does not support peer access" ;
981+ os << " Given devices do not support peer access" ;
1005982 error_handler (os.str (), __FILE__, __func__, __LINE__);
1006983 }
1007- try {
1008- D->ext_oneapi_disable_peer_access (*PD);
1009- } catch (std::exception const &e) {
1010- error_handler (e, __FILE__, __func__, __LINE__);
1011- }
1012984 }
1013985 return ;
1014986}
0 commit comments