@@ -239,6 +239,43 @@ cdef inline bint _check_peer_access(SyclDevice dev, SyclDevice peer) except *:
239239 return False
240240
241241
242+ cdef inline void _raise_invalid_peer_access(
243+ SyclDevice dev,
244+ SyclDevice peer,
245+ ) except * :
246+ """
247+ Check peer access ahead of time and raise errors for invalid cases.
248+ """
249+ cdef list _peer_access_backends = [
250+ _backend_type._CUDA,
251+ _backend_type._HIP,
252+ _backend_type._LEVEL_ZERO
253+ ]
254+ cdef _backend_type BTy1 = DPCTLDevice_GetBackend(dev._device_ref)
255+ cdef _backend_type BTy2 = DPCTLDevice_GetBackend(peer.get_device_ref())
256+ if (BTy1 != BTy2):
257+ raise ValueError (
258+ f" Device with backend {_backend_type_to_filter_string_part(BTy1)} "
259+ " cannot peer access device with backend "
260+ f" {_backend_type_to_filter_string_part(BTy2)}"
261+ )
262+ if (BTy1 not in _peer_access_backends):
263+ raise ValueError (
264+ " Peer access not supported for backend "
265+ f" {_backend_type_to_filter_string_part(BTy1)}"
266+ )
267+ if (BTy2 not in _peer_access_backends):
268+ raise ValueError (
269+ " Peer access not supported for backend "
270+ f" {_backend_type_to_filter_string_part(BTy2)}"
271+ )
272+ if (dev == peer):
273+ raise ValueError (
274+ " Peer access cannot be enabled between a device and itself"
275+ )
276+ return
277+
278+
242279@ functools.lru_cache (maxsize = None )
243280def _cached_filter_string (d : SyclDevice ):
244281 """
@@ -1850,7 +1887,6 @@ cdef class SyclDevice(_SyclDevice):
18501887 f" {type(peer)}"
18511888 )
18521889 p_dev = < SyclDevice> peer
1853-
18541890 if _check_peer_access(self , p_dev):
18551891 return DPCTLDevice_CanAccessPeer(
18561892 self ._device_ref,
@@ -1893,7 +1929,6 @@ cdef class SyclDevice(_SyclDevice):
18931929 f" {type(peer)}"
18941930 )
18951931 p_dev = < SyclDevice> peer
1896-
18971932 if _check_peer_access(self , p_dev):
18981933 return DPCTLDevice_CanAccessPeer(
18991934 self ._device_ref,
@@ -1931,14 +1966,11 @@ cdef class SyclDevice(_SyclDevice):
19311966 f" {type(peer)}"
19321967 )
19331968 p_dev = < SyclDevice> peer
1934-
1935- if _check_peer_access(self , p_dev):
1936- DPCTLDevice_EnablePeerAccess(
1937- self ._device_ref,
1938- p_dev.get_device_ref()
1939- )
1940- else :
1941- raise ValueError (" Peer access cannot be enabled for these devices" )
1969+ _raise_invalid_peer_access(self , p_dev)
1970+ DPCTLDevice_EnablePeerAccess(
1971+ self ._device_ref,
1972+ p_dev.get_device_ref()
1973+ )
19421974 return
19431975
19441976 def disable_peer_access (self , peer ):
@@ -1969,14 +2001,12 @@ cdef class SyclDevice(_SyclDevice):
19692001 f" {type(peer)}"
19702002 )
19712003 p_dev = < SyclDevice> peer
1972-
2004+ _raise_invalid_peer_access( self , p_dev)
19732005 if _check_peer_access(self , p_dev):
19742006 DPCTLDevice_DisablePeerAccess(
19752007 self ._device_ref,
19762008 p_dev.get_device_ref()
19772009 )
1978- else :
1979- raise ValueError (" Peer access cannot be enabled for these devices" )
19802010 return
19812011
19822012 @property
0 commit comments