@@ -341,3 +341,69 @@ def test_get_component_devices_from_composite():
341341 assert d .has_aspect_is_component
342342 # component devices are root devices
343343 assert d in devices
344+
345+
346+ @pytest .mark .parametrize ("platform_name" , ["level_zero" , "cuda" , "hip" ])
347+ def test_can_access_peer (platform_name ):
348+ """
349+ Test checks for peer access.
350+ """
351+ try :
352+ platform = dpctl .SyclPlatform (platform_name )
353+ except ValueError as e :
354+ pytest .skip (f"{ str (e )} { platform_name } " )
355+ devices = platform .get_devices ()
356+ if len (devices ) < 2 :
357+ pytest .skip (
358+ f"Platform { platform_name } does not have enough devices to "
359+ "test peer access"
360+ )
361+ dev0 = devices [0 ]
362+ dev1 = devices [1 ]
363+ assert isinstance (dev0 .can_access_peer_access_supported (dev1 ), bool )
364+ assert isinstance (dev0 .can_access_peer_atomics_supported (dev1 ), bool )
365+
366+
367+ @pytest .mark .parametrize ("platform_name" , ["level_zero" , "cuda" , "hip" ])
368+ def test_enable_disable_peer (platform_name ):
369+ """
370+ Test that peer access can be enabled and disabled.
371+ """
372+ try :
373+ platform = dpctl .SyclPlatform (platform_name )
374+ except ValueError as e :
375+ pytest .skip (f"{ str (e )} { platform_name } " )
376+ devices = platform .get_devices ()
377+ if len (devices ) < 2 :
378+ pytest .skip (
379+ f"Platform { platform_name } does not have enough devices to "
380+ "test peer access"
381+ )
382+ dev0 = devices [0 ]
383+ dev1 = devices [1 ]
384+ if dev0 .can_access_peer_access_supported (dev1 ):
385+ dev0 .enable_peer_access (dev1 )
386+ dev0 .disable_peer_access (dev1 )
387+ else :
388+ pytest .skip (
389+ f"Provided { platform_name } devices do not support peer access"
390+ )
391+
392+
393+ def test_peer_device_arg_validation ():
394+ """
395+ Test for validation of arguments to peer access related methods.
396+ """
397+ try :
398+ dev = dpctl .SyclDevice ()
399+ except dpctl .SyclDeviceCreationError :
400+ pytest .skip ("No default device available" )
401+ bad_dev = dict ()
402+ with pytest .raises (TypeError ):
403+ dev .can_access_peer_access_supported (bad_dev )
404+ with pytest .raises (TypeError ):
405+ dev .can_access_peer_atomics_supported (bad_dev )
406+ with pytest .raises (TypeError ):
407+ dev .enable_peer_access (bad_dev )
408+ with pytest .raises (TypeError ):
409+ dev .disable_peer_access (bad_dev )
0 commit comments