diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md
index a17ae3e1e9..18f35113e3 100644
--- a/docs/source/Instruction/Command-line-parameters.md
+++ b/docs/source/Instruction/Command-line-parameters.md
@@ -843,3 +843,4 @@ qwen2_5_omni除了包含qwen2_5_vl和qwen2_audio的模型特定参数外,还
 - VLLM_USE_V1: 用于切换vLLM使用V0/V1版本。
 - SWIFT_TIMEOUT: (ms-swift>=3.10) 若多模态数据集中存在图像URL,该参数用于控制获取图片的timeout,默认为20s。
 - ROOT_IMAGE_DIR: (ms-swift>=3.8) 图像(多模态)资源的根目录。通过设置该参数,可以在数据集中使用相对于 `ROOT_IMAGE_DIR` 的相对路径。默认情况下,是相对于运行目录的相对路径。
+- SWIFT_SINGLE_DEVICE_MODE: (ms-swift>=3.10) 单设备模式,在此模式下,所有进程只能看到一个设备,目前用于兼容PPU设备。
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index d4e0b9319d..5baaaa5d78 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -868,3 +868,4 @@ The meanings of the following parameters can be found in the example code [here]
 - VLLM_USE_V1: Used to switch between V0 and V1 versions of vLLM.
 - SWIFT_TIMEOUT: (ms-swift >= 3.10) If the multimodal dataset contains image URLs, this parameter controls the timeout for fetching images, defaulting to 20 seconds.
 - ROOT_IMAGE_DIR: (ms-swift>=3.8) The root directory for image (multimodal) resources. By setting this parameter, relative paths in the dataset can be interpreted relative to `ROOT_IMAGE_DIR`. By default, paths are relative to the current working directory.
+- SWIFT_SINGLE_DEVICE_MODE: (ms-swift>=3.10) Single device mode. In this mode, all processes can only see one device. Currently used for compatibility with PPU devices.
diff --git a/swift/cli/pt.py b/swift/cli/pt.py
index 1ca2aabd8a..60477214b1 100644
--- a/swift/cli/pt.py
+++ b/swift/cli/pt.py
@@ -1,5 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import pt_main
 
 if __name__ == '__main__':
+    from swift.cli.utils import try_use_single_device_mode
+    try_use_single_device_mode()
+    from swift.llm import pt_main
     pt_main()
diff --git a/swift/cli/rlhf.py b/swift/cli/rlhf.py
index 4f0fd6a0ab..5d8400fc5a 100644
--- a/swift/cli/rlhf.py
+++ b/swift/cli/rlhf.py
@@ -1,5 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import rlhf_main
 
 if __name__ == '__main__':
+    from swift.cli.utils import try_use_single_device_mode
+    try_use_single_device_mode()
+    from swift.llm import rlhf_main
     rlhf_main()
diff --git a/swift/cli/sft.py b/swift/cli/sft.py
index 4e780e141b..27076381da 100644
--- a/swift/cli/sft.py
+++ b/swift/cli/sft.py
@@ -11,6 +11,8 @@ def try_init_unsloth():
 
 
 if __name__ == '__main__':
+    from swift.cli.utils import try_use_single_device_mode
+    try_use_single_device_mode()
     try_init_unsloth()
     from swift.ray import try_init_ray
     try_init_ray()
diff --git a/swift/cli/utils.py b/swift/cli/utils.py
new file mode 100644
index 0000000000..f0c8002ff2
--- /dev/null
+++ b/swift/cli/utils.py
@@ -0,0 +1,32 @@
+import os
+
+
+def try_use_single_device_mode():
+    """Export a single-entry ``CUDA_VISIBLE_DEVICES`` when ``SWIFT_SINGLE_DEVICE_MODE=1``.
+
+    Reads ``CUDA_VISIBLE_DEVICES`` and ``LOCAL_RANK``, keeps only the device at
+    index ``LOCAL_RANK`` in ``CUDA_VISIBLE_DEVICES`` and resets ``LOCAL_RANK``
+    to ``'0'``. Does nothing when the mode is off or either variable is unset
+    or empty.
+
+    Raises:
+        ValueError: If ``LOCAL_RANK`` is not an integer or does not index an
+            entry of ``CUDA_VISIBLE_DEVICES``.
+    """
+    if os.environ.get('SWIFT_SINGLE_DEVICE_MODE', '0') != '1':
+        return
+    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
+    local_rank = os.environ.get('LOCAL_RANK')
+    if local_rank is None or not visible_devices:
+        return
+    # Tolerate spaces and stray commas, e.g. '0, 1,'.
+    devices = [d.strip() for d in visible_devices.split(',') if d.strip()]
+    if not devices:
+        return
+    rank = int(local_rank)
+    # Reject negative ranks too: they would silently index from the end.
+    if not 0 <= rank < len(devices):
+        raise ValueError(f'LOCAL_RANK={local_rank} is out of range for '
+                         f'CUDA_VISIBLE_DEVICES={visible_devices}')
+    os.environ['CUDA_VISIBLE_DEVICES'] = devices[rank]
+    os.environ['LOCAL_RANK'] = '0'
diff --git a/swift/utils/env.py b/swift/utils/env.py
index 3553492e97..bfc5c1d07f 100644
--- a/swift/utils/env.py
+++ b/swift/utils/env.py
@@ -66,10 +66,13 @@ def is_mp() -> bool:
     from swift.utils import get_device_count
     n_gpu = get_device_count()
     local_world_size = get_dist_setting()[3]
-    assert n_gpu % local_world_size == 0, f'n_gpu: {n_gpu}, local_world_size: {local_world_size}'
-    if n_gpu // local_world_size >= 2:
-        return True
-    return False
+    if os.environ.get('SWIFT_SINGLE_DEVICE_MODE', '0') != '1':
+        assert n_gpu % local_world_size == 0, f'n_gpu: {n_gpu}, local_world_size: {local_world_size}'
+        if n_gpu // local_world_size >= 2:
+            return True
+        return False
+    else:
+        return False
 
 
 def is_mp_ddp() -> bool: