Skip to content

Commit 48b96ec

Browse files
committed
Update extended attention interface
Address review comments Fix CI XPU not found
1 parent c14abd5 commit 48b96ec

File tree

6 files changed

+100
-93
lines changed

6 files changed

+100
-93
lines changed

.github/workflows/third-party-benchmarks.yml

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,34 @@ jobs:
7171

7272
- name: Setup Triton
7373
uses: ./.github/actions/setup-triton
74+
with:
75+
command: DEBUG=1 python setup.py bdist_wheel
7476

75-
- name: Install benchmarks
77+
- name: Install Triton
7678
id: install
7779
run: |
78-
cd benchmarks
79-
pip install .
80+
pip install dist/*.whl
8081
8182
- name: Install benchmark dependencies
82-
id: install_deps
83+
id: install
8384
run: |
8485
pip install transformers pandas pytest
8586
87+
cd benchmarks
88+
pip install .
89+
pip install intel-pti==0.12.2
90+
PTI_LIBS_DIR=$(python -c "import sysconfig; print(sysconfig.get_paths()['stdlib']+'/..')")
91+
# the output should contain: `libpti.so`, `libpti_metrics.so.0.12.2` and `libpti_view.so.0.12.2`
92+
ls $PTI_LIBS_DIR
93+
echo "PTI_LIBS_DIR=$PTI_LIBS_DIR" >> $GITHUB_ENV
94+
8695
- name: Create reports dir
8796
run: |
8897
mkdir reports
8998
echo "REPORTS=$PWD/reports" >> $GITHUB_ENV
9099
91100
- name: Run Liger-Kernel benchmarks
92-
if: ${{ steps.install_deps.outcome == 'success' && !cancelled() }}
101+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
93102
run: |
94103
source ./scripts/capture-hw-details.sh
95104
@@ -111,11 +120,22 @@ jobs:
111120
- name: Install SGLANG
112121
run: |
113122
git clone https://github.com/sgl-project/sglang.git
114-
pip install sglang/python[dev_xpu]
123+
cd sglang
124+
git apply ../benchmarks/third_party/sglang/sglang.patch
125+
pip install ./python[dev_xpu]
126+
127+
# Reinstallation since SGLang installation will force overrides current PyTorch and Triton
128+
- name: Reinstall PyTorch
129+
uses: ./.github/actions/setup-pytorch
130+
131+
- name: Reinstall Triton
132+
run: |
133+
pip install ./dist/*.whl
115134
116135
- name: Run SGLANG attention prefill stage benchmark
117136
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
118137
run: |
138+
export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
119139
cd benchmarks/third_party/sglang
120140
python prefill_attention_benchmark.py --reports $REPORTS
121141
@@ -125,6 +145,7 @@ jobs:
125145
- name: Run SGLANG attention decode stage benchmark
126146
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
127147
run: |
148+
export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
128149
cd benchmarks/third_party/sglang
129150
python decode_attention_benchmark.py --reports $REPORTS
130151
@@ -134,20 +155,22 @@ jobs:
134155
- name: Run SGLANG attention append stage benchmark
135156
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
136157
run: |
158+
export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
137159
cd benchmarks/third_party/sglang
138160
python extended_attention_benchmark.py --reports $REPORTS
139161
140162
source ../../../scripts/capture-hw-details.sh
141-
python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
163+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
142164
143165
- name: Run SGLANG Block FP8 GEMM benchmark
144166
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }}
145167
run: |
168+
export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
146169
cd benchmarks/third_party/sglang
147170
python block_fp8_gemm_benchmark.py --reports $REPORTS
148171
149172
source ../../../scripts/capture-hw-details.sh
150-
python ../../../scripts/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
173+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
151174
152175
- name: Upload benchmark reports
153176
if: ${{ steps.install.outcome == 'success' && !cancelled() }}

benchmarks/third_party/sglang/decode_attention_benchmark.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
4747
@benchmark_suit.perf_report(
4848
benchmark_suit.Benchmark(
4949
# argument names to use as an x-axis for the plot
50-
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
50+
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE'],
5151
x_vals=[ #
52-
[bs, [1024, 64], 32, 8, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
52+
[bs, 1024 + 64, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
5353
] + [ #
54-
[bs, [1024, 64], 32, 32, 96, 'fwd', False] for bs in [1, 16, 32, 64, 128]
54+
[bs, 1024 + 64, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128]
5555
] + [ #
56-
[bs, [1024, 64], 28, 4, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
56+
[bs, 1024 + 64, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
5757
],
5858
line_arg='provider',
5959
# argument name whose value corresponds to a different line in the plot
@@ -72,27 +72,22 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
7272
# name for the plot. Used also as a file name for saving the plot.
7373
args={},
7474
))
75-
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
75+
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, provider):
7676
torch.manual_seed(0)
7777
dtype = torch.bfloat16
7878
quantiles = [0.5, 0.0, 1.0]
79-
N_CTX = sum(SEQ_LENS)
79+
N_CTX = SEQ_LENS
8080

8181
q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args(
8282
B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
8383

84-
if provider == 'triton':
84+
if provider == 'triton' and MODE == 'fwd':
8585
triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
8686
num_kv_splits, max_kv_splits, sm_scale)
87-
88-
# TODO: decode attention should have the validation function
89-
if VALIDATE:
90-
raise NotImplementedError('Validation is not implemented for decode stage')
91-
9287
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
9388

9489
else:
95-
raise NotImplementedError(f'Unsupported provider {provider}')
90+
raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
9691

9792
tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * (1e-12) / (ms * 1e-3)
9893
gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)

benchmarks/third_party/sglang/extended_attention_benchmark.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
import torch
22
from sglang.srt.layers.attention.triton_ops.extend_attention import (
3-
extend_attention_fwd,
4-
redundant_attention,
5-
)
3+
extend_attention_fwd, )
64
import triton_kernels_benchmark as benchmark_suit
75

86

9-
def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
7+
# pylint: disable=unused-argument
8+
def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device):
109

11-
b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device)
12-
b_seq_len_extend = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device)
10+
b_seq_len_prefix = torch.full((B, ), PREFIX_LEN, dtype=torch.int32, device=device)
11+
b_seq_len_extend = torch.full((B, ), Q_LEN, dtype=torch.int32, device=device)
1312
b_seq_len = b_seq_len_prefix + b_seq_len_extend
14-
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
1513

16-
b_req_idx = torch.arange(B, dtype=torch.int32, device=device)
1714
b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device)
1815
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
1916
b_start_loc_extend = torch.zeros((B, ), dtype=torch.int32, device=device)
@@ -45,32 +42,30 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
4542
device=device).normal_(mean=0.1, std=0.2)
4643

4744
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
48-
o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
4945

5046
b_seq_len_extend = b_seq_len - b_seq_len_prefix
5147
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
5248
qo_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
5349
qo_indptr[1:B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
5450

5551
params = []
56-
params.append((q_extend, k_extend, v_extend, o_extend, o_redundant))
52+
params.append((q_extend, k_extend, v_extend, o_extend))
5753
params.append((k_buffer, v_buffer))
5854
params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend))
59-
params.append((b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch))
6055
return params
6156

6257

6358
# pylint: disable=unused-argument
6459
@benchmark_suit.perf_report(
6560
benchmark_suit.Benchmark(
6661
# argument names to use as an x-axis for the plot
67-
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
62+
x_names=['B', 'Q_LEN', 'PREFIX_LEN', 'KV_LEN', 'H_Q', 'H_KV', 'D', 'MODE'],
6863
x_vals=[ #
69-
[bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
64+
[bs, 512, 1024 + 128, 512, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
7065
] + [ #
71-
[bs, [1024, 128, 512], 32, 32, 96, 'fwd', True] for bs in [1, 16, 32, 64, 128]
66+
[bs, 512, 1024 + 128, 512, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128]
7267
] + [ #
73-
[bs, [1024, 128, 512], 28, 4, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
68+
[bs, 512, 1024 + 128, 512, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
7469
],
7570
line_arg='provider',
7671
# argument name whose value corresponds to a different line in the plot
@@ -89,41 +84,26 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
8984
# name for the plot. Used also as a file name for saving the plot.
9085
args={},
9186
))
92-
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
87+
def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, provider):
9388
torch.manual_seed(0)
9489

9590
dtype = torch.bfloat16
96-
N_CTX = sum(SEQ_LENS)
9791

98-
params = gen_args(B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
99-
q_extend, k_extend, v_extend, o_extend, o_redundant = params[0]
92+
params = gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, 'xpu')
93+
q_extend, k_extend, v_extend, o_extend = params[0]
10094
k_buffer, v_buffer = params[1]
10195
qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
102-
b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3]
10396
custom_mask = None
10497
mask_indptr = None
10598

10699
quantiles = [0.5, 0.0, 1.0]
107-
if provider == 'triton':
108-
109-
def triton_fn():
110-
extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr,
111-
kv_indices, custom_mask, mask_indptr, max_len_extend)
112-
return o_extend
113-
114-
if VALIDATE:
115-
116-
def refer_fn():
117-
redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len,
118-
b_seq_len_prefix, max_len_in_batch)
119-
return o_redundant
120-
121-
benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer')
122-
100+
if provider == 'triton' and MODE == 'fwd':
101+
triton_fn = lambda: extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr,
102+
kv_indptr, kv_indices, custom_mask, True, mask_indptr, max_len_extend)
123103
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
124104

125105
else:
126-
raise NotImplementedError(f'Unsupported provider {provider}')
106+
raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
127107

128108
N_CTX_TOTAL = k_buffer.shape[0]
129109
N_CTX_EXTEND = k_extend.shape[0]

benchmarks/third_party/sglang/prefill_attention_benchmark.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77

88
def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
9-
max_seq_len = max(SEQ_LENS)
10-
N_CTX = sum(SEQ_LENS)
9+
max_seq_len = SEQ_LENS
10+
N_CTX = SEQ_LENS
1111

1212
# Create random input tensors
1313
q = torch.randn((B * N_CTX, H_Q, D), device=device, dtype=dtype)
@@ -16,8 +16,8 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
1616
o = torch.zeros((B * N_CTX, H_Q, D), device=device, dtype=dtype)
1717

1818
# Create b_start_loc and b_seq_len tensors
19-
b_start_loc = torch.tensor([0, SEQ_LENS[0]], device=device)
20-
b_seq_len = torch.tensor(SEQ_LENS, device=device)
19+
b_start_loc = torch.tensor([0, SEQ_LENS], device=device)
20+
b_seq_len = torch.tensor([SEQ_LENS], device=device)
2121

2222
return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
2323

@@ -26,13 +26,13 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
2626
@benchmark_suit.perf_report(
2727
benchmark_suit.Benchmark(
2828
# argument names to use as an x-axis for the plot
29-
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'VALIDATE'],
29+
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE'],
3030
x_vals=[ #
31-
[bs, [1024], 32, 8, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
31+
[bs, 1024, 32, 8, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
3232
] + [ #
33-
[bs, [1024], 32, 32, 96, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
33+
[bs, 1024, 32, 32, 96, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
3434
] + [ #
35-
[bs, [1024], 28, 4, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
35+
[bs, 1024, 28, 4, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
3636
],
3737
line_arg='provider',
3838
# argument name whose value corresponds to a different line in the plot
@@ -51,43 +51,21 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
5151
# name for the plot. Used also as a file name for saving the plot.
5252
args={},
5353
))
54-
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, VALIDATE, provider):
54+
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, provider):
5555
torch.manual_seed(0)
5656
dtype = torch.bfloat16
57-
N_CTX = sum(SEQ_LENS)
5857

5958
q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu')
6059

6160
quantiles = [0.5, 0.0, 1.0]
62-
if provider == 'triton':
63-
61+
if provider == 'triton' and MODE == 'fwd':
6462
triton_fn = lambda: context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=CAUSAL)
65-
66-
if VALIDATE:
67-
# FIXME: torch sdpa does not support different H_Q and H_KV
68-
cu_seq_lens = [0] * (len(SEQ_LENS) + 1)
69-
for i, seq_len in enumerate(SEQ_LENS):
70-
cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len
71-
72-
for i in range(len(SEQ_LENS)):
73-
start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
74-
o_torch = torch.nn.functional.scaled_dot_product_attention(
75-
q[start:end].permute(1, 0, 2),
76-
k[start:end].permute(1, 0, 2),
77-
v[start:end].permute(1, 0, 2),
78-
is_causal=CAUSAL,
79-
).permute(1, 0, 2)
80-
81-
cos_sim = torch.nn.functional.cosine_similarity(o[start:end].flatten(), o_torch.flatten(), dim=0)
82-
assert cos_sim.item() > 1 - (1e-5)
83-
assert torch.allclose(o[start:end], o_torch, atol=1e-2)
84-
8563
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
8664
quantiles=quantiles)
87-
8865
else:
89-
raise NotImplementedError(f'Unsupported provider {provider}')
66+
raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
9067

68+
N_CTX = SEQ_LENS
9169
tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * N_CTX * D * (1e-12) / (ms * 1e-3)
9270
gbps = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * 2 * (1e-9) / (ms * 1e-3)
9371

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
2+
index 884e715f..580e2364 100644
3+
--- a/python/sglang/srt/utils.py
4+
+++ b/python/sglang/srt/utils.py
5+
@@ -77,12 +77,20 @@ from torch.func import functional_call
6+
from torch.library import Library
7+
from torch.profiler import ProfilerActivity, profile, record_function
8+
from torch.utils._contextlib import _DecoratorContextManager
9+
-from triton.runtime.cache import (
10+
- FileCacheManager,
11+
- default_cache_dir,
12+
- default_dump_dir,
13+
- default_override_dir,
14+
-)
15+
+try:
16+
+ from triton.runtime.cache import (
17+
+ FileCacheManager,
18+
+ default_cache_dir,
19+
+ default_dump_dir,
20+
+ default_override_dir,
21+
+ )
22+
+except ImportError:
23+
+ from triton.runtime.cache import FileCacheManager
24+
+ from triton.knobs import cache as tt_cache
25+
+
26+
+ default_cache_dir = lambda: tt_cache.dir
27+
+ default_dump_dir = lambda: tt_cache.dump_dir
28+
+ default_override_dir = lambda: tt_cache.override_dir
29+
30+
logger = logging.getLogger(__name__)
31+

benchmarks/triton_kernels_benchmark/build_report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def build_report(args: PassedArgs, results_df: Optional[pd.DataFrame] = None):
9090
df[p] = df[p].astype(int)
9191
df_results["params"] = [json.dumps(j) for j in df[[*param_cols, "MASK"]].to_dict("records")]
9292
else:
93-
df_results["params"] = [json.dumps(j) for j in df[param_cols].astype(str).to_dict("records")]
93+
df_results["params"] = [json.dumps(j) for j in df[param_cols].astype(int).to_dict("records")]
9494
df_results["tflops"] = df[args.tflops_col]
9595
if hbm_col is not None:
9696
df_results["hbm_gbs"] = df[hbm_col]

0 commit comments

Comments
 (0)