
Commit 59871e3

Update qMoE spec to support block quantization (microsoft#25641)
Update operator spec to support block quantization in qMoE. Implementation will come later.
1 parent 14ca6df commit 59871e3

2 files changed: +112 -48 lines changed


docs/ContribOperators.md

Lines changed: 36 additions & 18 deletions
@@ -3121,13 +3121,13 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 
 <dl>
 <dt><tt>input</tt> : T</dt>
-<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>2D input tensor with shape (num_tokens, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
 <dt><tt>router_probs</tt> : T</dt>
-<dd>2D input tensor with shape (num_rows, num_experts)</dd>
+<dd>2D input tensor with shape (num_tokens, num_experts)</dd>
 <dt><tt>fc1_experts_weights</tt> : T</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu</dd>
+<dd>3D input tensor with shape (num_experts, fusion_size * inter_size, hidden_size), where fusion_size is 2 for fused swiglu, and 1 otherwise</dd>
 <dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
-<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
+<dd>2D optional input tensor with shape (num_experts, fusion_size * inter_size)</dd>
 <dt><tt>fc2_experts_weights</tt> : T</dt>
 <dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
 <dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
@@ -3142,7 +3142,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 
 <dl>
 <dt><tt>output</tt> : T</dt>
-<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>2D output tensor with shape (num_tokens, hidden_size) or 3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
 </dl>
 
 #### Type Constraints
@@ -4532,7 +4532,23 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 
 ### <a name="com.microsoft.QMoE"></a><a name="com.microsoft.qmoe">**com.microsoft.QMoE**</a>
 
-Quantized MoE
+Quantized mixture of experts (MoE).
+
+Only weights are quantized with symmetric quantization.
+The quantized weights are stored in column-major order per expert.
+The quantization block size can be specified. If not provided, column-wise quantization is used.
+
+The SwiGLU (Swish-Gated Linear Unit) activation function is computed as follows:
+g = xW + b
+l = xV + c
+G = clamp(g, max=limit)
+L = clamp(l, min=-limit, max=limit)
+swiglu = G * sigmoid(alpha * G) * (L + beta)
+where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters.
+When swiglu_fusion=0, the two GEMMs are not fused; they are FC1 and FC3 in the inputs.
+When swiglu_fusion=1, the two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size.
+When swiglu_fusion=2, the two GEMMs are fused, and g and l are concatenated on each row.
+
 
 #### Version
 
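To make the SwiGLU formulas above concrete, here is a minimal scalar sketch of the activation exactly as written (an editorial illustration with a made-up helper name, not the ONNX Runtime kernel):

```cpp
#include <algorithm>
#include <cmath>

// Sketch of the SwiGLU formula from the qMoE description above, applied to one
// element pair (g, l) produced by the two GEMMs xW + b and xV + c.
float swiglu_element(float g, float l, float alpha, float beta, float limit) {
  float G = std::min(g, limit);                    // clamp(g, max=limit)
  float L = std::max(-limit, std::min(l, limit));  // clamp(l, min=-limit, max=limit)
  float sig = 1.0f / (1.0f + std::exp(-alpha * G));
  return G * sig * (L + beta);                     // G * sigmoid(alpha * G) * (L + beta)
}
```

With the defaults (activation_alpha = 1.0, activation_beta = 0.0) and no swiglu_limit, this reduces to the familiar g * sigmoid(g) * l.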
@@ -4547,6 +4563,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>Beta parameter used in activation function.</dd>
 <dt><tt>activation_type</tt> : string</dt>
 <dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
+<dt><tt>block_size</tt> : int</dt>
+<dd>Size of each quantization block along the K (input feature) dimension. Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128). If provided, both hidden_size and inter_size must be divisible by the block size. Otherwise, there is no blocking and a whole column shares one scaling factor.</dd>
 <dt><tt>expert_weight_bits</tt> : int</dt>
 <dd>Number of bits used in quantized weights. Default is 4 bits</dd>
 <dt><tt>k</tt> : int</dt>
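As a concrete reading of the block_size semantics, the sketch below dequantizes one weight column with one scale per block along K, and falls back to a single per-column scale when block_size is absent. It assumes the quantized values are already unpacked to signed integers (symmetric quantization, so no zero point); the helper is hypothetical and not the implementation, which the commit notes will come later.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: dequantize one column of an expert weight matrix.
// q holds hidden_size unpacked signed values; scales holds either a single
// per-column scale (no block_size) or hidden_size / block_size per-block scales.
std::vector<float> dequantize_column(const std::vector<int8_t>& q,
                                     const std::vector<float>& scales,
                                     int block_size) {
  std::vector<float> out(q.size());
  for (std::size_t k = 0; k < q.size(); ++k) {
    std::size_t scale_index =
        (block_size > 0) ? k / static_cast<std::size_t>(block_size) : 0;
    out[k] = static_cast<float>(q[k]) * scales[scale_index];  // symmetric: value * scale
  }
  return out;
}
```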
@@ -4565,34 +4583,34 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 
 <dl>
 <dt><tt>input</tt> : T</dt>
-<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>2D tensor with shape (num_tokens, hidden_size), or 3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
 <dt><tt>router_probs</tt> : T</dt>
-<dd>2D input tensor with shape (num_rows, num_experts)</dd>
+<dd>2D tensor with shape (num_tokens, num_experts)</dd>
 <dt><tt>fc1_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, inter_size, hidden_size / 2) for 4 bits. For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.</dd>
+<dd>3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / pack_size). The fusion_size is 2 for fused swiglu, or 1 otherwise. The pack_size is 8 / expert_weight_bits.</dd>
 <dt><tt>fc1_scales</tt> : T2</dt>
-<dd>2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
+<dd>2D tensor with shape (num_experts, fusion_size * inter_size), or 3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / block_size) when block_size is provided.</dd>
 <dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
-<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
+<dd>2D optional tensor with shape (num_experts, fusion_size * inter_size)</dd>
 <dt><tt>fc2_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2) for 4 bits</dd>
+<dd>3D tensor with shape (num_experts, hidden_size, inter_size / pack_size)</dd>
 <dt><tt>fc2_scales</tt> : T2</dt>
-<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
+<dd>2D tensor with shape (num_experts, hidden_size), or 3D tensor with shape (num_experts, hidden_size, inter_size / block_size) when block_size is provided.</dd>
 <dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
-<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
+<dd>2D optional tensor with shape (num_experts, hidden_size)</dd>
 <dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
-<dd>3D optional input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
+<dd>3D optional tensor with shape (num_experts, inter_size, hidden_size / pack_size)</dd>
 <dt><tt>fc3_scales</tt> (optional) : T2</dt>
-<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
+<dd>2D optional tensor with shape (num_experts, inter_size), or 3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size) when block_size is provided.</dd>
 <dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
-<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
+<dd>2D optional tensor with shape (num_experts, inter_size)</dd>
 </dl>
 
 #### Outputs
 
 <dl>
 <dt><tt>output</tt> : T</dt>
-<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>Output tensor with the same shape as the input</dd>
 </dl>
 
 #### Type Constraints
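To see how pack_size and block_size drive the shapes listed above, a short sketch that prints the expected fc1/fc2 dimensions for example sizes; every concrete number is made up for illustration.

```cpp
#include <cstdio>

int main() {
  // Illustrative sizes only; not taken from any real model.
  const int num_experts = 8, hidden_size = 2048, inter_size = 1024;
  const int expert_weight_bits = 4;              // 4-bit quantized weights
  const int pack_size = 8 / expert_weight_bits;  // 2 quantized values per uint8 element
  const int fusion_size = 2;                     // fused swiglu (1 for other activations)
  const int block_size = 32;                     // optional block quantization

  // fc1_experts_weights: (num_experts, fusion_size * inter_size, hidden_size / pack_size)
  std::printf("fc1_experts_weights: (%d, %d, %d)\n",
              num_experts, fusion_size * inter_size, hidden_size / pack_size);
  // fc1_scales with block_size: (num_experts, fusion_size * inter_size, hidden_size / block_size)
  std::printf("fc1_scales:          (%d, %d, %d)\n",
              num_experts, fusion_size * inter_size, hidden_size / block_size);
  // fc2_experts_weights: (num_experts, hidden_size, inter_size / pack_size)
  std::printf("fc2_experts_weights: (%d, %d, %d)\n",
              num_experts, hidden_size, inter_size / pack_size);
  // fc2_scales with block_size: (num_experts, hidden_size, inter_size / block_size)
  std::printf("fc2_scales:          (%d, %d, %d)\n",
              num_experts, hidden_size, inter_size / block_size);
  return 0;
}
```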

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 76 additions & 30 deletions
@@ -1412,22 +1412,41 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
 .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast<int64_t>(1))
 .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast<int64_t>(0))
 .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast<int64_t>(0))
-.Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
-.Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
-.Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu", "T")
-.Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional)
+.Input(0, "input", "2D input tensor with shape (num_tokens, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+.Input(1, "router_probs", "2D input tensor with shape (num_tokens, num_experts)", "T")
+.Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, fusion_size * inter_size, hidden_size), where fusion_size is 2 for fused swiglu, and 1 otherwise", "T")
+.Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, fusion_size * inter_size)", "T", OpSchema::Optional)
 .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")
 .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional)
 .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, inter_size, hidden_size)", "T", OpSchema::Optional)
 .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
-.Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+.Output(0, "output", "2D output tensor with shape (num_tokens, hidden_size) or 3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
 .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
 .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
 
+constexpr const char* qMoE_ver1_doc = R"DOC(
+Quantized mixture of experts (MoE).
+
+Only weights are quantized with symmetric quantization.
+The quantized weights are stored in column-major order per expert.
+The quantization block size can be specified. If not provided, column-wise quantization is used.
+
+The SwiGLU (Swish-Gated Linear Unit) activation function is computed as follows:
+g = xW + b
+l = xV + c
+G = clamp(g, max=limit)
+L = clamp(l, min=-limit, max=limit)
+swiglu = G * sigmoid(alpha * G) * (L + beta)
+where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters.
+When swiglu_fusion=0, the two GEMMs are not fused; they are FC1 and FC3 in the inputs.
+When swiglu_fusion=1, the two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size.
+When swiglu_fusion=2, the two GEMMs are fused, and g and l are concatenated on each row.
+)DOC";
+
 ONNX_MS_OPERATOR_SET_SCHEMA(
 QMoE, 1,
 OpSchema()
-.SetDoc("Quantized MoE")
+.SetDoc(qMoE_ver1_doc)
 .Attr("activation_type",
 "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu",
 AttributeProto::STRING,
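To make the swiglu_fusion=1 (interleaved) and swiglu_fusion=2 (concatenated) layouts described in the doc string concrete, a small sketch that recovers g and l from one fused FC1 output row of length 2 * inter_size. The ordering within each interleaved pair (g before l) is an assumption the spec text does not pin down, and the helper is hypothetical.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Sketch: split one fused FC1 output row into g and l per the swiglu_fusion layout.
std::pair<std::vector<float>, std::vector<float>> split_fused_row(
    const std::vector<float>& row, std::size_t inter_size, int swiglu_fusion) {
  std::vector<float> g(inter_size), l(inter_size);
  for (std::size_t i = 0; i < inter_size; ++i) {
    if (swiglu_fusion == 1) {  // interleaved: assumed order g0, l0, g1, l1, ...
      g[i] = row[2 * i];
      l[i] = row[2 * i + 1];
    } else {                   // swiglu_fusion == 2, concatenated: all g values, then all l values
      g[i] = row[i];
      l[i] = row[inter_size + i];
    }
  }
  return {g, l};
}
```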
@@ -1440,63 +1459,90 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
 "Whether to normalize routing weights",
 AttributeProto::INT,
 static_cast<int64_t>(0))
-.Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast<int64_t>(0))
+.Attr("use_sparse_mixer",
+"Whether to use sparse mixer",
+AttributeProto::INT,
+static_cast<int64_t>(0))
 .Attr("expert_weight_bits",
 "Number of bits used in quantized weights. Default is 4 bits",
 AttributeProto::INT,
 static_cast<int64_t>(4))
-.Attr("swiglu_fusion", "0: not fused, 1: fused and interleaved. 2: fused and not interleaved.", AttributeProto::INT, static_cast<int64_t>(0))
-.Attr("swiglu_limit", "The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.", AttributeProto::FLOAT, OPTIONAL_VALUE)
-.Attr("activation_alpha", "Alpha parameter used in activation function.", AttributeProto::FLOAT, 1.0f)
-.Attr("activation_beta", "Beta parameter used in activation function.", AttributeProto::FLOAT, 0.0f)
+.Attr("swiglu_fusion",
+"0: not fused, 1: fused and interleaved. 2: fused and not interleaved.",
+AttributeProto::INT,
+static_cast<int64_t>(0))
+.Attr("swiglu_limit",
+"The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.",
+AttributeProto::FLOAT,
+OPTIONAL_VALUE)
+.Attr("activation_alpha",
+"Alpha parameter used in activation function.",
+AttributeProto::FLOAT, 1.0f)
+.Attr("activation_beta",
+"Beta parameter used in activation function.",
+AttributeProto::FLOAT, 0.0f)
+.Attr("block_size",
+"Size of each quantization block along the K (input feature) dimension. "
+"Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128). "
+"If provided, both hidden_size and inter_size must be divisible by the block size. "
+"Otherwise, there is no blocking and a whole column shares one scaling factor.",
+AttributeProto::INT,
+OPTIONAL_VALUE)
 .Input(0,
 "input",
-"2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape "
-"(batch_size, sequence_length, hidden_size)",
+"2D tensor with shape (num_tokens, hidden_size), or "
+"3D tensor with shape (batch_size, sequence_length, hidden_size)",
+"T")
+.Input(1,
+"router_probs",
+"2D tensor with shape (num_tokens, num_experts)",
 "T")
-.Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
 .Input(2,
 "fc1_experts_weights",
-"3D input tensor with shape (num_experts, inter_size, hidden_size), "
-"or (num_experts, inter_size, hidden_size / 2) for 4 bits. "
-"For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), "
-"or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.",
+"3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / pack_size). "
+"The fusion_size is 2 for fused swiglu, or 1 otherwise. The pack_size is 8 / expert_weight_bits.",
 "T1")
-.Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T2")
+.Input(3,
+"fc1_scales",
+"2D tensor with shape (num_experts, fusion_size * inter_size), or "
+"3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / block_size) when block_size is provided.",
+"T2")
 .Input(4,
 "fc1_experts_bias",
-"2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional)
+"2D optional tensor with shape (num_experts, fusion_size * inter_size)", "T", OpSchema::Optional)
 .Input(5,
 "fc2_experts_weights",
-"3D input tensor with shape (num_experts, hidden_size, inter_size) "
-"or (num_experts, hidden_size, inter_size / 2) for 4 bits",
+"3D tensor with shape (num_experts, hidden_size, inter_size / pack_size)",
 "T1")
-.Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T2")
+.Input(6,
+"fc2_scales",
+"2D tensor with shape (num_experts, hidden_size), or "
+"3D tensor with shape (num_experts, hidden_size, inter_size / block_size) when block_size is provided.",
+"T2")
 .Input(7,
 "fc2_experts_bias",
-"2D optional input tensor with shape (num_experts, hidden_size)",
+"2D optional tensor with shape (num_experts, hidden_size)",
 "T",
 OpSchema::Optional)
 .Input(8,
 "fc3_experts_weights",
-"3D optional input tensor with shape (num_experts, inter_size, hidden_size) "
-"or (num_experts, inter_size, hidden_size / 2)",
+"3D optional tensor with shape (num_experts, inter_size, hidden_size / pack_size)",
 "T1",
 OpSchema::Optional)
 .Input(9,
 "fc3_scales",
-"2D optional input tensor with shape (num_experts, inter_size)",
+"2D optional tensor with shape (num_experts, inter_size), or "
+"3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size) when block_size is provided.",
 "T2",
 OpSchema::Optional)
 .Input(10,
 "fc3_experts_bias",
-"2D optional input tensor with shape (num_experts, inter_size)",
+"2D optional tensor with shape (num_experts, inter_size)",
 "T",
 OpSchema::Optional)
 .Output(0,
 "output",
-"2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape "
-"(batch_size, sequence_length, hidden_size)",
+"Output tensor with the same shape as the input",
 "T")
 .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
 .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.")
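The block_size attribute above states three constraints: the value is a power of two, at least 16, and divides both hidden_size and inter_size. A minimal check capturing those rules (a hypothetical helper, not part of the schema):

```cpp
#include <cstdint>

// Sketch of the block_size validity rules quoted in the attribute description.
bool IsValidBlockSize(std::int64_t block_size, std::int64_t hidden_size, std::int64_t inter_size) {
  bool power_of_two = block_size > 0 && (block_size & (block_size - 1)) == 0;
  return power_of_two && block_size >= 16 &&
         hidden_size % block_size == 0 && inter_size % block_size == 0;
}
```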
