From a820b3c6f3d45007bb085e554400751fa6daf117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B8=A1=E8=85=BF?= <16110130+gitoor@user.noreply.gitee.com> Date: Tue, 4 Nov 2025 17:42:37 +0800 Subject: [PATCH 1/3] [feat] Add a callback for MFU logging support --- swift/plugin/callback.py | 149 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 147 insertions(+), 2 deletions(-) diff --git a/swift/plugin/callback.py b/swift/plugin/callback.py index 3be54be959..f685cb4004 100644 --- a/swift/plugin/callback.py +++ b/swift/plugin/callback.py @@ -1,8 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import time + import numpy as np +import torch from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments -from swift.utils import get_logger +from swift.utils import get_logger, get_current_device, get_device_count logger = get_logger() @@ -28,6 +31,148 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer control.should_training_stop = True -extra_callbacks = [] +class PerfMetricsLogCallback(TrainerCallback): + """A callback for perf metrics (MFU etc) log implementation""" + + def __init__(self): + self.start_time = None + self.device_tflops = None + self.elapsed = 0.0 + self.step_start_time = None + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + from swift.utils import get_env_args + + # Top priority. Specify by ENV + tflops = get_env_args('DEVICE_TFLOPS', int, None) + device_count = max(get_device_count(), 1) + if tflops is not None: + logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]") + else: + # Run a estimating test. + dtype = kwargs.get("model").dtype + device = torch.device(get_current_device()) + logger.info(f"Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]") + tflops = self._estimate_device_tflops_by_dtype(device, dtype) + logger.info(f"Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]") + + self.device_tflops = tflops * device_count + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + self.step_start_time = time.time() + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + self.elapsed += time.time() - self.step_start_time + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + self.start_time = time.time() + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs): + total_flos = getattr(state, 'total_flos', 0) + actual_flops = total_flos / self.elapsed + theoretical_max_flops = self.device_tflops * 1e12 + mfu = actual_flops / theoretical_max_flops + logger.debug(f"Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]") + logs['MFU'] = round(mfu, 6) + + @staticmethod + def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, + dim: int = 8192): + # 默认矩阵规模 + shape = (dim, dim) + backend = device.type + if backend == "npu": + import torch_npu + + # 创建矩阵 + a = torch.randn(*shape, device=device, dtype=dtype) + b = torch.randn(*shape, device=device, dtype=dtype) + + # 预热 + for _ in range(5): + c = torch.matmul(a, b) + if backend == 'cuda': + torch.cuda.synchronize(device) + elif backend == 'npu': + torch.npu.synchronize(device) + elif backend == 'cpu': + torch.cpu.synchronize(device) + + # 进行测试 + start = time.time() + for _ in range(repeats): + c = torch.matmul(a, b) + if backend == 'cuda': + torch.cuda.synchronize(device) + elif backend == 'npu': + torch.npu.synchronize(device) + elif backend == 'cpu': + torch.cpu.synchronize(device) + end = time.time() + total_time = end - start + avg_time = total_time / repeats + + # 若测试时间过短,调整循环次数并重新测试 + if total_time < 3: + repeats = int(6 / avg_time) + start = time.time() + for _ in range(repeats): + c = torch.matmul(a, b) + if backend == 'cuda': + torch.cuda.synchronize(device) + elif backend == 'npu': + torch.npu.synchronize(device) + elif backend == 'cpu': + torch.cpu.synchronize(device) + end = time.time() + total_time = end - start + avg_time = total_time / repeats + + del a, b, c + if backend == 'cuda': + torch.cuda.empty_cache() + elif backend == 'npu': + torch.npu.empty_cache() + + tflops = (2 * dim ** 3 / avg_time) / 1e12 + print( + f"[设备 {device}] 测试总耗时:{total_time:.4f}s,平均耗时: {avg_time:.4f} s,dtype:{dtype},性能: {tflops:.4f} TFLOPS") + + return tflops + + @staticmethod + def _retrieve_flops_from_map(device): + """Retrieve theoretical FLOPS from Map. """ + + device_name = device.get_device_name() + flops = None + for name, value in device_flops_map: + if name in device_name: + flops = value + break + + return flops + + +device_flops_map = { + "GB200": 2.5e15, + "B200": 2.25e15, + "MI300X": 1336e12, + "H100": 312e12, + "H800": 312e12, + "H200": 989e12, + "A100": 312e12, + "A800": 312e12, + "L40S": 362.05e12, + "L40": 181.05e12, + "A40": 149.7e12, + "L20": 119.5e12, + "H20": 148e12, + "910B": 354e12, + "Ascend910": 354e12, + "RTX 3070 Ti": 21.75e12 +} + +extra_callbacks = [PerfMetricsLogCallback()] # This example shows a simple example of EarlyStop Callback, uncomment this to use # extra_callbacks = [EarlyStopCallback()] From 655eedbb7f5eda5f409a1a8a2fc308f009f14c6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B8=A1=E8=85=BF?= <16110130+gitoor@user.noreply.gitee.com> Date: Wed, 5 Nov 2025 09:10:55 +0800 Subject: [PATCH 2/3] [fix] Fix pre-commit check --- swift/plugin/callback.py | 55 ++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/swift/plugin/callback.py b/swift/plugin/callback.py index f685cb4004..3af6477db7 100644 --- a/swift/plugin/callback.py +++ b/swift/plugin/callback.py @@ -5,7 +5,7 @@ import torch from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments -from swift.utils import get_logger, get_current_device, get_device_count +from swift.utils import get_current_device, get_device_count, get_logger logger = get_logger() @@ -32,7 +32,7 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer class PerfMetricsLogCallback(TrainerCallback): - """A callback for perf metrics (MFU etc) log implementation""" + """An callback for perf metrics (MFU etc) log implementation""" def __init__(self): self.start_time = None @@ -50,11 +50,12 @@ def on_init_end(self, args: TrainingArguments, state: TrainerState, control: Tra logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]") else: # Run a estimating test. - dtype = kwargs.get("model").dtype + dtype = kwargs.get('model').dtype device = torch.device(get_current_device()) - logger.info(f"Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]") + logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]') tflops = self._estimate_device_tflops_by_dtype(device, dtype) - logger.info(f"Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]") + logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]') + # TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables. self.device_tflops = tflops * device_count @@ -72,16 +73,15 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerC actual_flops = total_flos / self.elapsed theoretical_max_flops = self.device_tflops * 1e12 mfu = actual_flops / theoretical_max_flops - logger.debug(f"Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]") + logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]') logs['MFU'] = round(mfu, 6) @staticmethod - def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, - dim: int = 8192): + def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192): # 默认矩阵规模 shape = (dim, dim) backend = device.type - if backend == "npu": + if backend == 'npu': import torch_npu # 创建矩阵 @@ -134,9 +134,8 @@ def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, r elif backend == 'npu': torch.npu.empty_cache() - tflops = (2 * dim ** 3 / avg_time) / 1e12 - print( - f"[设备 {device}] 测试总耗时:{total_time:.4f}s,平均耗时: {avg_time:.4f} s,dtype:{dtype},性能: {tflops:.4f} TFLOPS") + tflops = (2 * dim**3 / avg_time) / 1e12 + print(f'[设备 {device}] 测试总耗时:{total_time:.4f}s,平均耗时: {avg_time:.4f} s,dtype:{dtype},性能: {tflops:.4f} TFLOPS') return tflops @@ -155,22 +154,22 @@ def _retrieve_flops_from_map(device): device_flops_map = { - "GB200": 2.5e15, - "B200": 2.25e15, - "MI300X": 1336e12, - "H100": 312e12, - "H800": 312e12, - "H200": 989e12, - "A100": 312e12, - "A800": 312e12, - "L40S": 362.05e12, - "L40": 181.05e12, - "A40": 149.7e12, - "L20": 119.5e12, - "H20": 148e12, - "910B": 354e12, - "Ascend910": 354e12, - "RTX 3070 Ti": 21.75e12 + 'GB200': 2.5e15, + 'B200': 2.25e15, + 'MI300X': 1336e12, + 'H100': 312e12, + 'H800': 312e12, + 'H200': 989e12, + 'A100': 312e12, + 'A800': 312e12, + 'L40S': 362.05e12, + 'L40': 181.05e12, + 'A40': 149.7e12, + 'L20': 119.5e12, + 'H20': 148e12, + '910B': 354e12, + 'Ascend910': 354e12, + 'RTX 3070 Ti': 21.75e12 } extra_callbacks = [PerfMetricsLogCallback()] From 190ee31a918557fe8e2797b074929c7b72b94d18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B8=A1=E8=85=BF?= <16110130+gitoor@user.noreply.gitee.com> Date: Wed, 5 Nov 2025 10:04:26 +0800 Subject: [PATCH 3/3] [refactor] improve code style --- swift/plugin/callback.py | 51 +++++++++++++++------------------------- 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/swift/plugin/callback.py b/swift/plugin/callback.py index 3af6477db7..00fbac32ea 100644 --- a/swift/plugin/callback.py +++ b/swift/plugin/callback.py @@ -5,7 +5,7 @@ import torch from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments -from swift.utils import get_current_device, get_device_count, get_logger +from swift.utils import get_logger logger = get_logger() @@ -35,13 +35,12 @@ class PerfMetricsLogCallback(TrainerCallback): """An callback for perf metrics (MFU etc) log implementation""" def __init__(self): - self.start_time = None self.device_tflops = None self.elapsed = 0.0 self.step_start_time = None def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - from swift.utils import get_env_args + from swift.utils import get_current_device, get_device_count, get_env_args # Top priority. Specify by ENV tflops = get_env_args('DEVICE_TFLOPS', int, None) @@ -65,9 +64,6 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): self.elapsed += time.time() - self.step_start_time - def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - self.start_time = time.time() - def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs): total_flos = getattr(state, 'total_flos', 0) actual_flops = total_flos / self.elapsed @@ -78,6 +74,16 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerC @staticmethod def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192): + from swift.utils.torch_utils import empty_cache + + def device_synchronize(sync_device): + if backend == 'cuda': + torch.cuda.synchronize(sync_device) + elif backend == 'npu': + torch.npu.synchronize(sync_device) + elif backend == 'cpu': + torch.cpu.synchronize(sync_device) + # 默认矩阵规模 shape = (dim, dim) backend = device.type @@ -91,23 +97,13 @@ def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, r # 预热 for _ in range(5): c = torch.matmul(a, b) - if backend == 'cuda': - torch.cuda.synchronize(device) - elif backend == 'npu': - torch.npu.synchronize(device) - elif backend == 'cpu': - torch.cpu.synchronize(device) + device_synchronize(device) # 进行测试 start = time.time() for _ in range(repeats): c = torch.matmul(a, b) - if backend == 'cuda': - torch.cuda.synchronize(device) - elif backend == 'npu': - torch.npu.synchronize(device) - elif backend == 'cpu': - torch.cpu.synchronize(device) + device_synchronize(device) end = time.time() total_time = end - start avg_time = total_time / repeats @@ -118,25 +114,16 @@ def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, r start = time.time() for _ in range(repeats): c = torch.matmul(a, b) - if backend == 'cuda': - torch.cuda.synchronize(device) - elif backend == 'npu': - torch.npu.synchronize(device) - elif backend == 'cpu': - torch.cpu.synchronize(device) + device_synchronize(device) end = time.time() total_time = end - start avg_time = total_time / repeats del a, b, c - if backend == 'cuda': - torch.cuda.empty_cache() - elif backend == 'npu': - torch.npu.empty_cache() + empty_cache() tflops = (2 * dim**3 / avg_time) / 1e12 - print(f'[设备 {device}] 测试总耗时:{total_time:.4f}s,平均耗时: {avg_time:.4f} s,dtype:{dtype},性能: {tflops:.4f} TFLOPS') - + logger.info(f'[Device {device}] Total time: {total_time:.4f}s, dtype: {dtype}, Perf: {tflops:.4f} TFLOPS') return tflops @staticmethod @@ -145,7 +132,7 @@ def _retrieve_flops_from_map(device): device_name = device.get_device_name() flops = None - for name, value in device_flops_map: + for name, value in device_flops_map.items(): if name in device_name: flops = value break @@ -174,4 +161,4 @@ def _retrieve_flops_from_map(device): extra_callbacks = [PerfMetricsLogCallback()] # This example shows a simple example of EarlyStop Callback, uncomment this to use -# extra_callbacks = [EarlyStopCallback()] +# extra_callbacks = [EarlyStopCallback()] \ No newline at end of file