Skip to content

Commit f42d078

Browse files
committed
Merge branch 'kaczmarj-add/conv3d'
2 parents 37801a5 + dcdd07b commit f42d078

File tree

4 files changed

+100
-4
lines changed

4 files changed

+100
-4
lines changed

python/tensorflowjs/op_list/convolution.json

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,5 +371,44 @@
371371
"type": "number[]"
372372
}
373373
]
374+
},
375+
{
376+
"tfOpName": "Conv3D",
377+
"category": "convolution",
378+
"inputs": [
379+
{
380+
"start": 0,
381+
"name": "x",
382+
"type": "tensor"
383+
},
384+
{
385+
"start": 1,
386+
"name": "filter",
387+
"type": "tensor"
388+
}
389+
],
390+
"attrs": [
391+
{
392+
"tfName": "strides",
393+
"name": "strides",
394+
"type": "number[]"
395+
},
396+
{
397+
"tfName": "padding",
398+
"name": "pad",
399+
"type": "string"
400+
},
401+
{
402+
"tfName": "data_format",
403+
"name": "dataFormat",
404+
"type": "string",
405+
"defaultValue": "NHWC"
406+
},
407+
{
408+
"tfName": "dilations",
409+
"name": "dilations",
410+
"type": "number[]"
411+
}
412+
]
374413
}
375414
]

src/operations/executors/convolution_executor.ts

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,24 @@ export let executeOp: InternalOpExecutor =
9191
[stride[1], stride[2]], pad as 'valid' | 'same',
9292
dataFormat as 'NHWC' | 'NCHW', [dilations[1], dilations[2]])];
9393
}
94+
case 'Conv3D': {
95+
const stride =
96+
getParamValue('strides', node, tensorMap, context) as number[];
97+
const pad = getParamValue('pad', node, tensorMap, context);
98+
const dataFormat =
99+
(getParamValue('dataFormat', node, tensorMap, context) as string)
100+
.toUpperCase();
101+
const dilations =
102+
getParamValue('dilations', node, tensorMap, context) as number[];
103+
return [tfc.conv3d(
104+
getParamValue('x', node, tensorMap, context) as tfc.Tensor4D |
105+
tfc.Tensor<tfc.Rank.R5>,
106+
getParamValue('filter', node, tensorMap, context) as
107+
tfc.Tensor<tfc.Rank.R5>,
108+
[stride[1], stride[2], stride[3]], pad as 'valid' | 'same',
109+
dataFormat as 'NDHWC' | 'NCDHW',
110+
[dilations[1], dilations[2], dilations[3]])];
111+
}
94112

95113
case 'AvgPool': {
96114
const stride =
@@ -130,8 +148,7 @@ export let executeOp: InternalOpExecutor =
130148
return [tfc.avgPool3d(
131149
getParamValue('x', node, tensorMap, context) as tfc.Tensor5D,
132150
[kernelSize[1], kernelSize[2], kernelSize[3]],
133-
[stride[1], stride[2], stride[3]],
134-
pad as 'valid' | 'same')];
151+
[stride[1], stride[2], stride[3]], pad as 'valid' | 'same')];
135152
}
136153

137154
case 'MaxPool3D': {
@@ -144,8 +161,7 @@ export let executeOp: InternalOpExecutor =
144161
return [tfc.maxPool3d(
145162
getParamValue('x', node, tensorMap, context) as tfc.Tensor5D,
146163
[kernelSize[1], kernelSize[2], kernelSize[3]],
147-
[stride[1], stride[2], stride[3]],
148-
pad as 'valid' | 'same')];
164+
[stride[1], stride[2], stride[3]], pad as 'valid' | 'same')];
149165
}
150166

151167
default:

src/operations/executors/convolution_executor_test.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,29 @@ describe('convolution', () => {
157157
});
158158
});
159159

160+
describe('Conv3d', () => {
161+
it('should call tfc.conv3d', () => {
162+
spyOn(tfc, 'conv3d');
163+
node.op = 'Conv3D';
164+
node.category = 'convolution';
165+
node.inputParams['filter'] = createTensorAttr(1);
166+
node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 2, 1]);
167+
node.attrParams['pad'] = createStrAttr('same');
168+
node.attrParams['dataFormat'] = createStrAttr('NHWC');
169+
node.attrParams['dilations'] = createNumericArrayAttr([1, 2, 2, 2, 1]);
170+
171+
const input1 = [tfc.scalar(1.0)];
172+
const input2 = [tfc.scalar(1.0)];
173+
node.inputNames = ['input1', 'input2'];
174+
175+
executeOp(node, {input1, input2}, context);
176+
177+
expect(tfc.conv3d)
178+
.toHaveBeenCalledWith(
179+
input1[0], input2[0], [2, 2, 2], 'same', 'NHWC', [2, 2, 2]);
180+
});
181+
});
182+
160183
describe('AvgPool3D', () => {
161184
it('should call tfc.avgPool3d', () => {
162185
spyOn(tfc, 'avgPool3d');

src/operations/op_list/convolution.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,5 +186,23 @@ export const json: OpMapper[] = [
186186
},
187187
{'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]'}
188188
]
189+
},
190+
{
191+
'tfOpName': 'Conv3D',
192+
'category': 'convolution',
193+
'inputs': [
194+
{'start': 0, 'name': 'x', 'type': 'tensor'},
195+
{'start': 1, 'name': 'filter', 'type': 'tensor'},
196+
],
197+
'attrs': [
198+
{'tfName': 'strides', 'name': 'strides', 'type': 'number[]'},
199+
{'tfName': 'padding', 'name': 'pad', 'type': 'string'}, {
200+
'tfName': 'data_format',
201+
'name': 'dataFormat',
202+
'type': 'string',
203+
'defaultValue': 'NHWC'
204+
},
205+
{'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]'}
206+
],
189207
}
190208
];

0 commit comments

Comments
 (0)