-
Notifications
You must be signed in to change notification settings - Fork 224
Fixing GPU Adapter Count test to be more dynamic and fail resistent #4038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
01b3d6c
57121bd
1750183
c2aaf45
7ce447f
51351e7
69c3735
027528a
29269e8
378e101
d6b7bb9
810ac73
bbe2fc4
2eee4a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| from dataclasses import dataclass | ||
| from enum import Enum | ||
| from functools import partial | ||
| from typing import Any, List, Type | ||
| from typing import Any, Dict, List, Type | ||
|
|
||
| from dataclasses_json import dataclass_json | ||
|
|
||
|
|
@@ -136,17 +136,121 @@ def install_compute_sdk(self, version: str = "") -> None: | |
| raise LisaException(f"{driver} is not a valid value of ComputeSDK") | ||
|
|
||
| def get_gpu_count_with_lsvmbus(self) -> int: | ||
| """ | ||
| Count GPU devices using lsvmbus. | ||
| First tries known list, then groups devices by last segment of device ID. | ||
| """ | ||
| lsvmbus_tool = self._node.tools[Lsvmbus] | ||
|
|
||
| # Get all VMBus devices | ||
| vmbus_devices = lsvmbus_tool.get_device_channels() | ||
| self._log.debug(f"Found {len(vmbus_devices)} VMBus devices") | ||
|
|
||
| # First try the known list (original approach) | ||
| gpu_count = self._get_gpu_count_from_known_list(vmbus_devices) | ||
|
|
||
| if gpu_count > 0: | ||
| self._log.debug(f"Found {gpu_count} GPU(s) using known list") | ||
| return gpu_count | ||
|
|
||
| # If no matches in known list, group by last segment | ||
| self._log.debug("No GPUs found in known list, trying last-segment grouping") | ||
| gpu_count = self._get_gpu_count_by_device_id_segment(vmbus_devices) | ||
|
|
||
| if gpu_count > 0: | ||
| self._log.debug(f"Found {gpu_count} GPU(s) using last-segment grouping") | ||
| else: | ||
| self._log.debug("No GPU devices found in lsvmbus") | ||
|
|
||
| return gpu_count | ||
|
|
||
| def _get_gpu_count_by_device_id_segment(self, vmbus_devices: List[Any]) -> int: | ||
|
||
| """ | ||
| Group VMBus devices by last segment of device ID and find GPU group. | ||
| GPUs typically share the same last segment (e.g., '423331303142' for GB200). | ||
| """ | ||
| try: | ||
| # Get actual GPU count from nvidia-smi | ||
| nvidia_smi = self._node.tools[NvidiaSmi] | ||
| # Get GPU count from nvidia-smi without using pre-existing list | ||
| actual_gpu_count = nvidia_smi.get_gpu_count(known_only=False) | ||
|
|
||
| if actual_gpu_count == 0: | ||
| self._log.debug("nvidia-smi reports 0 GPUs") | ||
| return 0 | ||
|
|
||
| self._log.debug(f"nvidia-smi reports {actual_gpu_count} GPU(s)") | ||
|
|
||
| # Group devices by last segment of device ID | ||
| last_segment_groups: Dict[str, List[Any]] = {} | ||
|
|
||
| for device in vmbus_devices: | ||
| device_id = device.device_id | ||
| # Device ID format: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX | ||
| # Extract last segment after the last hyphen | ||
| id_parts = device_id.split("-") | ||
| if len(id_parts) >= 5: | ||
| last_segment = id_parts[-1].lower() | ||
| if last_segment not in last_segment_groups: | ||
| last_segment_groups[last_segment] = [] | ||
| last_segment_groups[last_segment].append(device) | ||
|
|
||
| # Find a group with exactly the GPU count | ||
| for last_segment, devices in last_segment_groups.items(): | ||
| if len(devices) == actual_gpu_count: | ||
| # all should be PCI Express pass-through devices | ||
| all_pci_passthrough = all( | ||
| "PCI Express pass-through" in device.name for device in devices | ||
| ) | ||
|
|
||
| if all_pci_passthrough: | ||
| self._log.debug( | ||
| f"Found {len(devices)} PCI Express pass-through devices " | ||
| f"with last segment '{last_segment}' matching GPU count" | ||
| ) | ||
| # Log the matched devices for debugging | ||
| for device in devices: | ||
| self._log.debug(f" GPU device: {device.device_id}") | ||
| return actual_gpu_count | ||
|
|
||
| # If no exact match, log what we found for debugging | ||
| self._log.debug( | ||
| f"No device group with last segment matches " | ||
| f"GPU count {actual_gpu_count}" | ||
| ) | ||
| for last_segment, devices in last_segment_groups.items(): | ||
| # Only log groups with PCI Express pass-through devices | ||
| pci_devices = [ | ||
| d for d in devices if "PCI Express pass-through" in d.name | ||
| ] | ||
| if pci_devices: | ||
| self._log.debug( | ||
| f" Last segment '{last_segment}': " | ||
| f"{len(pci_devices)} PCI devices" | ||
| ) | ||
|
|
||
| return 0 | ||
|
|
||
| except Exception as e: | ||
| self._log.debug(f"Last-segment grouping failed: {e}") | ||
| return 0 | ||
|
|
||
| def _get_gpu_count_from_known_list(self, vmbus_devices: List[Any]) -> int: | ||
| """ | ||
| Original method - check against known list of GPUs | ||
| """ | ||
| lsvmbus_device_count = 0 | ||
| bridge_device_count = 0 | ||
|
|
||
| lsvmbus_tool = self._node.tools[Lsvmbus] | ||
| device_list = lsvmbus_tool.get_device_channels() | ||
| for device in device_list: | ||
| for device in vmbus_devices: | ||
| for name, id_, bridge_count in NvidiaSmi.gpu_devices: | ||
| if id_ in device.device_id: | ||
| lsvmbus_device_count += 1 | ||
| bridge_device_count = bridge_count | ||
| self._log.debug(f"GPU device {name} found!") | ||
| self._log.debug( | ||
| f"GPU device {name} found using hardcoded list! " | ||
| f"Device ID: {device.device_id}" | ||
| ) | ||
| break | ||
|
|
||
| return lsvmbus_device_count - bridge_device_count | ||
|
|
@@ -156,7 +260,7 @@ def get_gpu_count_with_lspci(self) -> int: | |
|
|
||
| def get_gpu_count_with_vendor_cmd(self) -> int: | ||
| nvidiasmi = self._node.tools[NvidiaSmi] | ||
| return nvidiasmi.get_gpu_count() | ||
| return nvidiasmi.get_gpu_count(known_only=False) | ||
|
|
||
| def get_supported_driver(self) -> List[ComputeSDK]: | ||
| raise NotImplementedError() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.