Skip to content

Commit bff4539

Browse files
qjia7annxingyuan
authored andcommitted
webgpu: Fix the compile error in conv2d_naive shader (#2002)
BUG When switch to the naive method of conv2d, below error will be met: Error: Shader compilation failed: file:124: error: 'getX' : no matching overloaded function found. Error: Shader compilation failed: file:131: error: 'getW' : no matching overloaded function found. The fixing uses the right get* methods. And uses int instead of uint as the parameter type.
1 parent ee8b2ae commit bff4539

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,19 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
4444
() => 'TODO: Dilation is unimplemented');
4545

4646
this.userCode = `
47-
float readInp(uint batch, uint row, uint col, uint chan) {
47+
float readInp(int batch, int row, int col, int chan) {
4848
ivec4 coord = ivec4(batch, row, col, chan);
49-
return coordIsValid(coord, xShape) ? getX(coord) : 0;
49+
return coordIsValid(coord, xShape) ?
50+
getX(batch, row, col, chan) : 0;
5051
}
5152
52-
float readFilt(uint row, uint col, uint xChannel, uint outChannel) {
53-
ivec4 shape = ivec4(filterDims, xShape[3], outShape[3]);
54-
return coordIsValid(coord, shape) ?
53+
float readFilt(int row, int col, int xChannel, int outChannel) {
54+
ivec4 coord = ivec4(row, col, xChannel, outChannel);
55+
return coordIsValid(coord, wShape) ?
5556
getW(row, col, xChannel, outChannel) : 0;
5657
}
5758
58-
void writeResult(uint batch, uint row, uint col, uint chan, float value) {
59+
void writeResult(int batch, int row, int col, int chan, float value) {
5960
ivec4 coord = ivec4(batch, row, col, chan);
6061
if (coordIsValid(coord, outShape)) {
6162
setOutput(batch, row, col, chan, value);

0 commit comments

Comments
 (0)