@@ -1064,7 +1064,7 @@ def __init__(
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.flashinfer_moe_backend = None
-
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
         if self.allow_flashinfer:
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
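The new `_cache_permute_indices` dict is keyed by `torch.Size`, so a permutation only has to be computed once per weight shape and is then reused for every expert with that shape. A minimal sketch of that get-or-compute pattern, not the flashinfer implementation (the `_maybe_get_cached_*_permute_indices` helpers used later encapsulate it internally); `_get_or_compute_permute_indices` and `compute_fn` are hypothetical names standing in for whatever builds the indices:

import torch
from typing import Callable

def _get_or_compute_permute_indices(
        cache: dict[torch.Size, torch.Tensor],
        weight: torch.Tensor,
        compute_fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    # All experts of a layer share one weight shape, so the permutation is
    # computed once per shape and then served from the cache.
    if weight.shape not in cache:
        cache[weight.shape] = compute_fn(weight)
    return cache[weight.shape]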
@@ -1197,19 +1197,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                               weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)

-    def prepare_static_weight_layouts_for_trtllm_moe(
+    def prepare_static_weights_for_trtllm_fp4_moe(
         self,
         gemm1_weights: torch.Tensor,
         gemm2_weights: torch.Tensor,
         gemm1_scales_linear_fp4_bytes: torch.Tensor,
         gemm2_scales_linear_fp4_bytes: torch.Tensor,
         hidden_size: int,
         intermediate_size: int,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Prepare quantized weights for kernel (done offline with weights)."""
-        from flashinfer import (reorder_rows_for_gated_act_gemm,
-                                shuffle_matrix_a, shuffle_matrix_sf_a)
+        from flashinfer import nvfp4_block_scale_interleave
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices)
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

         # Convert quantized weights to proper formats
@@ -1227,48 +1231,54 @@ def prepare_static_weight_layouts_for_trtllm_moe(
                                          intermediate_size //
                                          16)  # fp8 scaling factors

-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(
-                    gemm1_scales_linear_fp4[i].clone()))
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved).reshape(num_experts,
-                                                   2 * intermediate_size,
-                                                   hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved).reshape(num_experts,
-                                                  2 * intermediate_size,
-                                                  hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
-            gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm1_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
-
-            gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
-                                 epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm1_scales_linear_fp4.device)].contiguous()))
+
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm2_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8),
-                    epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm2_scales_linear_fp4.device)].contiguous()))

         # Stack weights for all experts
         gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
@@ -1283,8 +1293,12 @@ def prepare_static_weight_layouts_for_trtllm_moe(
             torch.stack(gemm2_scales_fp4_shuffled).view(
                 torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                              intermediate_size // 16))
-        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
-                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
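For reference, a toy illustration of the indexing pattern the new per-expert loop relies on: the cached indices select rows of the uint8 view of each expert's weight, and `.contiguous()` materializes the shuffled layout. The shapes and the random permutation here are made up; the real indices come from the flashinfer helpers used above.

import torch

num_rows, row_bytes = 256, 64                   # illustrative sizes only
weight = torch.randint(0, 256, (num_rows, row_bytes), dtype=torch.uint8)
permute_indices = torch.randperm(num_rows)      # stands in for the cached indices

# Same pattern as the diff: index with the permutation, then pack densely.
shuffled = weight[permute_indices.to(weight.device)].contiguous()
assert shuffled.shape == weight.shape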

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1 processing
@@ -1334,9 +1348,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.allow_flashinfer and \
             self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
             # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
             (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
              gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
-             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+             ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                  layer.w13_weight,
                  layer.w2_weight,
                  layer.w13_weight_scale,
@@ -1345,6 +1360,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.w13_weight.size(-2) // 2,  # intermediate_size
                 layer.w13_weight.size(0),  # num_experts
             )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

             layer.gemm1_weights_fp4_shuffled = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False)