Skip to content

Commit 37801a5

Browse files
syt123450pyu10055
authored andcommitted
Add avgPool3d & maxPool3d ops to graph model executor (#375)
* Add maxPool3d & avgPool3d to executor * update version
1 parent 998cc54 commit 37801a5

File tree

7 files changed

+185
-6
lines changed

7 files changed

+185
-6
lines changed

docs/supported_ops.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,14 @@
9494
|Tensorflow Op Name|Tensorflow.js Op Name|
9595
|---|---|
9696
|AvgPool|avgPool|
97+
|AvgPool3D|avgPool3d|
9798
|Conv1D|conv1d|
9899
|Conv2D|conv2d|
99100
|Conv2DBackpropInput|conv2dTranspose|
100101
|DepthwiseConv2d|depthwiseConv2d|
101102
|DepthwiseConv2dNative|depthwiseConv2d|
102103
|MaxPool|maxPool|
104+
|MaxPool3D|maxPool3d|
103105
|Not mapped|pool|
104106
|Not mapped|separableConv2d|
105107

package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
},
1616
"license": "Apache-2.0",
1717
"peerDependencies": {
18-
"@tensorflow/tfjs-core": "1.2.7"
18+
"@tensorflow/tfjs-core": "1.2.8"
1919
},
2020
"devDependencies": {
21-
"@tensorflow/tfjs-core": "1.2.7",
21+
"@tensorflow/tfjs-core": "1.2.8",
2222
"@types/deep-equal": "^1.0.1",
2323
"@types/jasmine": "~2.8.6",
2424
"@types/long": "~3.0.32",

python/tensorflowjs/op_list/convolution.json

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,86 @@
7979
}
8080
]
8181
},
82+
{
83+
"tfOpName": "AvgPool3D",
84+
"category": "convolution",
85+
"inputs": [
86+
{
87+
"start": 0,
88+
"name": "x",
89+
"type": "tensor"
90+
}
91+
],
92+
"attrs": [
93+
{
94+
"tfName": "strides",
95+
"name": "strides",
96+
"type": "number[]"
97+
},
98+
{
99+
"tfName": "padding",
100+
"name": "pad",
101+
"type": "string"
102+
},
103+
{
104+
"tfName": "data_format",
105+
"name": "dataFormat",
106+
"type": "string",
107+
"notSupported": true
108+
},
109+
{
110+
"tfName": "ksize",
111+
"name": "kernelSize",
112+
"type": "number[]"
113+
},
114+
{
115+
"tfName": "T",
116+
"name": "dtype",
117+
"type": "dtype",
118+
"notSupported": true
119+
}
120+
]
121+
},
122+
{
123+
"tfOpName": "MaxPool3D",
124+
"category": "convolution",
125+
"inputs": [
126+
{
127+
"start": 0,
128+
"name": "x",
129+
"type": "tensor"
130+
}
131+
],
132+
"attrs": [
133+
{
134+
"tfName": "strides",
135+
"name": "strides",
136+
"type": "number[]"
137+
},
138+
{
139+
"tfName": "padding",
140+
"name": "pad",
141+
"type": "string"
142+
},
143+
{
144+
"tfName": "data_format",
145+
"name": "dataFormat",
146+
"type": "string",
147+
"notSupported": true
148+
},
149+
{
150+
"tfName": "ksize",
151+
"name": "kernelSize",
152+
"type": "number[]"
153+
},
154+
{
155+
"tfName": "T",
156+
"name": "dtype",
157+
"type": "dtype",
158+
"notSupported": true
159+
}
160+
]
161+
},
82162
{
83163
"tfOpName": "Conv1D",
84164
"category": "convolution",

src/operations/executors/convolution_executor.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,35 @@ export let executeOp: InternalOpExecutor =
119119
[kernelSize[1], kernelSize[2]], [stride[1], stride[2]],
120120
pad as 'valid' | 'same')];
121121
}
122+
123+
case 'AvgPool3D': {
124+
const stride =
125+
getParamValue('strides', node, tensorMap, context) as number[];
126+
const pad = getParamValue('pad', node, tensorMap, context);
127+
const kernelSize =
128+
getParamValue('kernelSize', node, tensorMap, context) as number[];
129+
130+
return [tfc.avgPool3d(
131+
getParamValue('x', node, tensorMap, context) as tfc.Tensor5D,
132+
[kernelSize[1], kernelSize[2], kernelSize[3]],
133+
[stride[1], stride[2], stride[3]],
134+
pad as 'valid' | 'same')];
135+
}
136+
137+
case 'MaxPool3D': {
138+
const stride =
139+
getParamValue('strides', node, tensorMap, context) as number[];
140+
const pad = getParamValue('pad', node, tensorMap, context);
141+
const kernelSize =
142+
getParamValue('kernelSize', node, tensorMap, context) as number[];
143+
144+
return [tfc.maxPool3d(
145+
getParamValue('x', node, tensorMap, context) as tfc.Tensor5D,
146+
[kernelSize[1], kernelSize[2], kernelSize[3]],
147+
[stride[1], stride[2], stride[3]],
148+
pad as 'valid' | 'same')];
149+
}
150+
122151
default:
123152
throw TypeError(`Node type ${node.op} is not implemented`);
124153
}

src/operations/executors/convolution_executor_test.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,5 +156,37 @@ describe('convolution', () => {
156156
input1[0], input2[0], [2, 2], 'same', 'NHWC', [2, 2]);
157157
});
158158
});
159+
160+
describe('AvgPool3D', () => {
161+
it('should call tfc.avgPool3d', () => {
162+
spyOn(tfc, 'avgPool3d');
163+
node.op = 'AvgPool3D';
164+
node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 2, 1]);
165+
node.attrParams['pad'] = createStrAttr('same');
166+
node.attrParams['kernelSize'] =
167+
createNumericArrayAttr([1, 2, 2, 2, 1]);
168+
169+
executeOp(node, {input}, context);
170+
171+
expect(tfc.avgPool3d)
172+
.toHaveBeenCalledWith(input[0], [2, 2, 2], [2, 2, 2], 'same');
173+
});
174+
});
175+
176+
describe('MaxPool3D', () => {
177+
it('should call tfc.maxPool3d', () => {
178+
spyOn(tfc, 'maxPool3d');
179+
node.op = 'MaxPool3D';
180+
node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 2, 1]);
181+
node.attrParams['pad'] = createStrAttr('same');
182+
node.attrParams['kernelSize'] =
183+
createNumericArrayAttr([1, 2, 2, 2, 1]);
184+
185+
executeOp(node, {input}, context);
186+
187+
expect(tfc.maxPool3d)
188+
.toHaveBeenCalledWith(input[0], [2, 2, 2], [2, 2, 2], 'same');
189+
});
190+
});
159191
});
160192
});

src/operations/op_list/convolution.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,42 @@ export const json: OpMapper[] = [
5454
{'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
5555
]
5656
},
57+
{
58+
'tfOpName': 'AvgPool3D',
59+
'category': 'convolution',
60+
'inputs': [
61+
{'start': 0, 'name': 'x', 'type': 'tensor'},
62+
],
63+
'attrs': [
64+
{'tfName': 'strides', 'name': 'strides', 'type': 'number[]'},
65+
{'tfName': 'padding', 'name': 'pad', 'type': 'string'}, {
66+
'tfName': 'data_format',
67+
'name': 'dataFormat',
68+
'type': 'string',
69+
'notSupported': true
70+
},
71+
{'tfName': 'ksize', 'name': 'kernelSize', 'type': 'number[]'},
72+
{'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
73+
]
74+
},
75+
{
76+
'tfOpName': 'MaxPool3D',
77+
'category': 'convolution',
78+
'inputs': [
79+
{'start': 0, 'name': 'x', 'type': 'tensor'},
80+
],
81+
'attrs': [
82+
{'tfName': 'strides', 'name': 'strides', 'type': 'number[]'},
83+
{'tfName': 'padding', 'name': 'pad', 'type': 'string'}, {
84+
'tfName': 'data_format',
85+
'name': 'dataFormat',
86+
'type': 'string',
87+
'notSupported': true
88+
},
89+
{'tfName': 'ksize', 'name': 'kernelSize', 'type': 'number[]'},
90+
{'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
91+
]
92+
},
5793
{
5894
'tfOpName': 'Conv1D',
5995
'category': 'convolution',

yarn.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@
5555
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570"
5656
integrity sha1-p3c2C1s5oaLlEG+OhY8v0tBgxXA=
5757

58-
"@tensorflow/tfjs-core@1.2.7":
59-
version "1.2.7"
60-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.2.7.tgz#522328de16470aa9f7c15b91e4b68616f425002a"
61-
integrity sha512-RsXavYKMc0MOcCmOyD7HE8am1tWlDGXl0nJbsdib7ubmvMuH6KnrZ302eTYV7k1RMq+/ukkioJmCcw13hopuHQ==
58+
"@tensorflow/tfjs-core@1.2.8":
59+
version "1.2.8"
60+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.2.8.tgz#d6873b88522f8cf25d34c10afd095866578d7d92"
61+
integrity sha512-lWV4vAnXAAmahXpCWBwdGGW9HO6iNw9pUeVYih7pDXeJahMk3OJs6SgjRNhwn+ldsGwRoorR0/RHg0yNLmqWxQ==
6262
dependencies:
6363
"@types/offscreencanvas" "~2019.3.0"
6464
"@types/seedrandom" "2.4.27"

0 commit comments

Comments
 (0)