Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 133 additions & 2 deletions swift/plugin/callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import time

import numpy as np
import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

from swift.utils import get_logger
Expand Down Expand Up @@ -28,6 +31,134 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer
control.should_training_stop = True


extra_callbacks = []
class PerfMetricsLogCallback(TrainerCallback):
"""An callback for perf metrics (MFU etc) log implementation"""

def __init__(self):
self.device_tflops = None
self.elapsed = 0.0
self.step_start_time = None

def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
from swift.utils import get_current_device, get_device_count, get_env_args

# Top priority. Specify by ENV
tflops = get_env_args('DEVICE_TFLOPS', int, None)
device_count = max(get_device_count(), 1)
if tflops is not None:
logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]")
else:
# Run a estimating test.
dtype = kwargs.get('model').dtype
device = torch.device(get_current_device())
logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
tflops = self._estimate_device_tflops_by_dtype(device, dtype)
logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
# TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables.

self.device_tflops = tflops * device_count

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
self.step_start_time = time.time()

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
self.elapsed += time.time() - self.step_start_time

def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
total_flos = getattr(state, 'total_flos', 0)
actual_flops = total_flos / self.elapsed
theoretical_max_flops = self.device_tflops * 1e12
mfu = actual_flops / theoretical_max_flops
logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]')
logs['MFU'] = round(mfu, 6)

@staticmethod
def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192):
from swift.utils.torch_utils import empty_cache

def device_synchronize(sync_device):
if backend == 'cuda':
torch.cuda.synchronize(sync_device)
elif backend == 'npu':
torch.npu.synchronize(sync_device)
elif backend == 'cpu':
torch.cpu.synchronize(sync_device)

# 默认矩阵规模
shape = (dim, dim)
backend = device.type
if backend == 'npu':
import torch_npu

# 创建矩阵
a = torch.randn(*shape, device=device, dtype=dtype)
b = torch.randn(*shape, device=device, dtype=dtype)

# 预热
for _ in range(5):
c = torch.matmul(a, b)
device_synchronize(device)

# 进行测试
start = time.time()
for _ in range(repeats):
c = torch.matmul(a, b)
device_synchronize(device)
end = time.time()
total_time = end - start
avg_time = total_time / repeats

# 若测试时间过短,调整循环次数并重新测试
if total_time < 3:
repeats = int(6 / avg_time)
start = time.time()
for _ in range(repeats):
c = torch.matmul(a, b)
device_synchronize(device)
end = time.time()
total_time = end - start
avg_time = total_time / repeats

del a, b, c
empty_cache()

tflops = (2 * dim**3 / avg_time) / 1e12
logger.info(f'[Device {device}] Total time: {total_time:.4f}s, dtype: {dtype}, Perf: {tflops:.4f} TFLOPS')
return tflops

@staticmethod
def _retrieve_flops_from_map(device):
"""Retrieve theoretical FLOPS from Map. """

device_name = device.get_device_name()
flops = None
for name, value in device_flops_map.items():
if name in device_name:
flops = value
break

return flops


device_flops_map = {
'GB200': 2.5e15,
'B200': 2.25e15,
'MI300X': 1336e12,
'H100': 312e12,
'H800': 312e12,
'H200': 989e12,
'A100': 312e12,
'A800': 312e12,
'L40S': 362.05e12,
'L40': 181.05e12,
'A40': 149.7e12,
'L20': 119.5e12,
'H20': 148e12,
'910B': 354e12,
'Ascend910': 354e12,
'RTX 3070 Ti': 21.75e12
}

extra_callbacks = [PerfMetricsLogCallback()]
# This example shows a simple example of EarlyStop Callback, uncomment this to use
# extra_callbacks = [EarlyStopCallback()]
# extra_callbacks = [EarlyStopCallback()]