Skip to content

Commit d9b9168

Browse files
WebGPU JSEP: Make shader code not depend on input broadcasting patterns (microsoft#22536)
This PR makes MatMul shaders not depend on the inputs' broadcasting pattern, but only on the input ranks and their shapes provided in uniforms. This change fixes the issue that currently the shader code differs for different broadcasting patterns yet has an identical cache key, which results in wrong cache hits.
1 parent 4d614e1 commit d9b9168

File tree

6 files changed

+311
-235
lines changed

6 files changed

+311
-235
lines changed

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import { ShapeUtil } from '../../../util';
2525
import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types';
2626
import {
2727
createTensorShapeVariables,
28-
getBroadcastDims,
2928
IndicesHelper,
3029
inputVariable,
3130
internalVariable,
@@ -40,6 +39,7 @@ import {
4039
getActivationSnippet,
4140
InternalActivationAttributes,
4241
} from '../fuse-utils';
42+
import { convertOutputBatchIndicesToInputBatchIndices } from '../matmul-shaders';
4343

4444
import { typeSnippet } from './activation_util';
4545

@@ -373,42 +373,11 @@ const matMulReadWriteFnSource = (
373373
hasBias: boolean,
374374
applyActivation: string,
375375
variables: IndicesHelper[],
376-
batchShapes: Array<readonly number[]>,
377376
isChannelsLast = false,
378377
): string => {
379-
const [batchAShape, batchBShape, batchShape] = batchShapes;
380378
const [batchVariable, aVariable, bVariable, outputVariable] = variables;
381-
const broadCastADims = getBroadcastDims(batchAShape, batchShape);
382-
const broadCastBDims = getBroadcastDims(batchBShape, batchShape);
383379
const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor);
384-
const getAIndices = () => {
385-
const aRank = aVariable.rank;
386-
const batchRank = batchVariable.rank;
387-
let resStr = `var aIndices: ${aVariable.type.indices};`;
388-
for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
389-
resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
390-
}
391-
broadCastADims.forEach((i) => {
392-
resStr += `\naIndices[${i}] = 0;`;
393-
});
394-
resStr += `\naIndices[${aRank - 2}] = u32(row);
395-
aIndices[${aRank - 1}] = u32(colIn);`;
396-
return resStr;
397-
};
398-
const getBIndices = () => {
399-
const bRank = bVariable.rank;
400-
const batchRank = batchVariable.rank;
401-
let resStr = `var bIndices: ${bVariable.type.indices};`;
402-
for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
403-
resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
404-
}
405-
broadCastBDims.forEach((i) => {
406-
resStr += `\nbIndices[${i}] = 0;`;
407-
});
408-
resStr += `\nbIndices[${bRank - 2}] = u32(row);
409-
bIndices[${bRank - 1}] = u32(colIn);`;
410-
return resStr;
411-
};
380+
412381
const source = `
413382
fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet(
414383
component,
@@ -418,7 +387,16 @@ const matMulReadWriteFnSource = (
418387
let col = colIn * ${component};
419388
if(row < uniforms.dim_a_outer && col < uniforms.dim_inner)
420389
{
421-
${getAIndices()}
390+
var aIndices: ${aVariable.type.indices};
391+
${convertOutputBatchIndicesToInputBatchIndices(
392+
'aIndices',
393+
aVariable,
394+
aVariable.rank - 2,
395+
batchVariable.rank,
396+
'batchIndices',
397+
)}
398+
${aVariable.indicesSet('aIndices', aVariable.rank - 2, 'u32(row)')}
399+
${aVariable.indicesSet('aIndices', aVariable.rank - 1, 'u32(colIn)')}
422400
value = ${aVariable.getByIndices('aIndices')};
423401
}
424402
return value;
@@ -432,7 +410,16 @@ const matMulReadWriteFnSource = (
432410
let col = colIn * ${component};
433411
if(row < uniforms.dim_inner && col < uniforms.dim_b_outer)
434412
{
435-
${getBIndices()}
413+
var bIndices: ${bVariable.type.indices};
414+
${convertOutputBatchIndicesToInputBatchIndices(
415+
'bIndices',
416+
bVariable,
417+
bVariable.rank - 2,
418+
batchVariable.rank,
419+
'batchIndices',
420+
)}
421+
${bVariable.indicesSet('bIndices', bVariable.rank - 2, 'u32(row)')}
422+
${bVariable.indicesSet('bIndices', bVariable.rank - 1, 'u32(colIn)')}
436423
value = ${bVariable.getByIndices('bIndices')};
437424
}
438425
return value;
@@ -532,7 +519,6 @@ export const createMatmulProgramInfo = (
532519
hasBias,
533520
applyActivation,
534521
[batchDims, A, B, output],
535-
[outerDimsA, outerDimsB, outerDims],
536522
isChannelsLast,
537523
);
538524
return `

js/web/lib/wasm/jsep/webgpu/ops/common.ts

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -996,27 +996,3 @@ class ShaderHelperImpl implements ShaderHelper {
996996

997997
export const createShaderHelper = (dispatchGroup: [number, number, number], limits: GPUSupportedLimits) =>
998998
new ShaderHelperImpl(dispatchGroup, limits);
999-
1000-
/**
1001-
* This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40
1002-
* Returns the dimensions in the input shape that are broadcasted to
1003-
* produce the provided output shape.
1004-
*
1005-
* The returned dimensions are 0-indexed and sorted. An example:
1006-
* inShape = [4, 1, 3]
1007-
* outShape = [5, 4, 3, 3]
1008-
* result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
1009-
*/
1010-
export const getBroadcastDims = (inShape: readonly number[], outShape: readonly number[]): number[] => {
1011-
const inRank = inShape.length;
1012-
const dims: number[] = [];
1013-
for (let i = 0; i < inRank; i++) {
1014-
const dim = inRank - 1 - i;
1015-
const a = inShape[dim] || 1;
1016-
const b = outShape[outShape.length - 1 - i] || 1;
1017-
if (b > 1 && a === 1) {
1018-
dims.unshift(dim);
1019-
}
1020-
}
1021-
return dims;
1022-
};

js/web/lib/wasm/jsep/webgpu/ops/conv.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import { computeConv3DInfo, createConv3DNaiveProgramInfo } from './3rd-party/con
1111
import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu';
1212
import { createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo } from './conv-grouped';
1313
import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
14-
import { createNaiveMatmulProgramInfo } from './matmul';
14+
import { createNaiveMatmulProgramInfo } from './matmul-shaders';
1515
import { createTransposeProgramInfo } from './transpose';
1616

1717
export const calculateOutputShape = (
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
import { DataType } from '../../../wasm-common';
5+
import { TensorView } from '../../tensor-view';
6+
import { ShapeUtil } from '../../util';
7+
import { ProgramInfo, ProgramUniform } from '../types';
8+
9+
import {
10+
createTensorShapeVariables,
11+
getElementAt,
12+
getMaxComponents,
13+
IndicesHelper,
14+
inputVariable,
15+
internalVariable,
16+
outputVariable,
17+
ShaderHelper,
18+
tensorTypeToWsglStorageType,
19+
UniformsArrayType,
20+
} from './common';
21+
import {
22+
appendActivationUniforms,
23+
appendActivationUniformsData,
24+
getActivationSnippet,
25+
InternalActivationAttributes,
26+
} from './fuse-utils';
27+
28+
// Helper that converts output batch indices to input batch indices using only the rank and
29+
// the shape information in uniform
30+
export const convertOutputBatchIndicesToInputBatchIndices = (
31+
targetIndicesName: string,
32+
inputVariable: IndicesHelper,
33+
inputBatchRank: number,
34+
outputBatchRank: number,
35+
batchIndicesName: string,
36+
) => {
37+
// Assumes outputBatchRank >= inputBatchRank; the first outputBatchRank - inputBatchRank
38+
// dimensions of the output batch indices should be ignored.
39+
const extendingInputRank = outputBatchRank - inputBatchRank;
40+
return `
41+
${Array.from({ length: inputBatchRank })
42+
.map(
43+
(_, i) => `
44+
if (${getElementAt(inputVariable.shape, i, inputVariable.rank)} != 1) {
45+
${inputVariable.indicesSet(targetIndicesName, i, getElementAt(batchIndicesName, i + extendingInputRank, outputBatchRank))}
46+
} else {
47+
${inputVariable.indicesSet(targetIndicesName, i, 0)}
48+
}`,
49+
)
50+
.join('')}
51+
`;
52+
};
53+
54+
export const createNaiveMatmulProgramInfo = (
55+
inputs: readonly TensorView[],
56+
activationAttributes: InternalActivationAttributes,
57+
outputShape: readonly number[],
58+
reshapedOutputShape?: readonly number[],
59+
isChannelsLast = false /* only used for conv2dByMatMul*/,
60+
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
61+
): ProgramInfo => {
62+
const aShape = inputs[0].dims;
63+
const bShape = inputs[1].dims;
64+
65+
const M = aShape[aShape.length - 2];
66+
const N = bShape[bShape.length - 1];
67+
const K = aShape[aShape.length - 1];
68+
const components = getMaxComponents(N);
69+
const aComponents = getMaxComponents(K);
70+
const outputNumber = getMaxComponents(M);
71+
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
72+
const hasBias = inputs.length > 2;
73+
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
74+
const batchSize = ShapeUtil.size(outerDims);
75+
const outputShapeInShader = [batchSize, M, N];
76+
77+
const programUniforms: ProgramUniform[] = [
78+
{ type: DataType.uint32, data: outputSize },
79+
{ type: DataType.uint32, data: M },
80+
{ type: DataType.uint32, data: N },
81+
{ type: DataType.uint32, data: K },
82+
];
83+
appendActivationUniformsData(activationAttributes, programUniforms);
84+
programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape));
85+
if (hasBias) {
86+
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
87+
}
88+
programUniforms.push(...createTensorShapeVariables(outputShapeInShader));
89+
90+
const getShaderSource = (shaderHelper: ShaderHelper) => {
91+
const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length);
92+
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
93+
const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
94+
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
95+
const baseType = tensorTypeToWsglStorageType(output.type.tensor);
96+
const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
97+
const inputVariables = [a, b];
98+
let processBias = '';
99+
if (hasBias) {
100+
const biasComponents = isChannelsLast ? components : 1;
101+
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
102+
processBias = `${
103+
isChannelsLast ? `value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);`
104+
}`;
105+
}
106+
107+
const uniforms: UniformsArrayType = [
108+
{ name: 'output_size', type: 'u32' },
109+
{ name: 'M', type: 'u32' },
110+
{ name: 'N', type: 'u32' },
111+
{ name: 'K', type: 'u32' },
112+
];
113+
appendActivationUniforms(activationAttributes, uniforms);
114+
115+
const calcResult = (): string => {
116+
let calcStr = `var a_data: ${a.type.value};`;
117+
for (let i = 0; i < aComponents; i++) {
118+
calcStr += `
119+
let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`;
120+
}
121+
for (let i = 0; i < outputNumber; i++) {
122+
calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`;
123+
124+
for (let j = 0; j < aComponents; j++) {
125+
calcStr += `
126+
values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`;
127+
}
128+
}
129+
return calcStr;
130+
};
131+
132+
return `
133+
${shaderHelper
134+
.registerUniforms(uniforms)
135+
.registerInternalVariables(batchDims)
136+
.declareVariables(...inputVariables, output)}
137+
${shaderHelper.mainStart()}
138+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
139+
let col = (global_idx % (uniforms.N / ${components})) * ${components};
140+
var index1 = global_idx / (uniforms.N / ${components});
141+
let stride1 = uniforms.M / ${outputNumber};
142+
let row = (index1 % stride1) * ${outputNumber};
143+
let batch = index1 / stride1;
144+
145+
${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
146+
147+
var a_indices: ${a.type.indices};
148+
${convertOutputBatchIndicesToInputBatchIndices('a_indices', a, a.rank - 2, batchDims.rank, 'batch_indices')}
149+
${a.indicesSet('a_indices', a.rank - 2, 0)}
150+
${a.indicesSet('a_indices', a.rank - 1, 0)}
151+
let a_offset = ${a.indicesToOffset('a_indices')};
152+
153+
var b_indices: ${b.type.indices};
154+
${convertOutputBatchIndicesToInputBatchIndices('b_indices', b, b.rank - 2, batchDims.rank, 'batch_indices')}
155+
${b.indicesSet('b_indices', b.rank - 2, 0)}
156+
${b.indicesSet('b_indices', b.rank - 1, 0)}
157+
let b_offset = ${b.indicesToOffset('b_indices')};
158+
var values: array<${output.type.value}, ${outputNumber}>;
159+
for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
160+
${calcResult()}
161+
}
162+
for (var i = 0u; i < ${outputNumber}u; i++) {
163+
var value = values[i];
164+
${processBias}
165+
${applyActivation}
166+
let cur_indices = ${output.type.indices}(batch, row + i, col);
167+
let offset = ${output.indicesToOffset('cur_indices')};
168+
${output.setByOffset(`offset / ${components}`, 'value')};
169+
}
170+
}
171+
`;
172+
};
173+
return {
174+
name: 'MatMulNaive',
175+
shaderCache: {
176+
hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`,
177+
inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
178+
},
179+
getRunData: () => ({
180+
outputs: [
181+
{
182+
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
183+
dataType: inputs[0].dataType,
184+
},
185+
],
186+
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
187+
programUniforms,
188+
}),
189+
getShaderSource,
190+
};
191+
};

0 commit comments

Comments
 (0)