Skip to content

Commit 868accc

Browse files
authored
Add more float16 data type support for existing models (#333)
Float16 data type can be easily enabled by inserting cast to inputs and outputs and converting the buffer weights from fp32 to fp16 during graph loading. This PR enables more float16 data type support for image classification and object detection. BTW, fixes several minor issues.
1 parent 06ae438 commit 868accc

20 files changed

+359
-327
lines changed

common/utils.js

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,24 +435,27 @@ export function permuteData(array, dims, axes) {
435435
export async function getDefaultLayout(deviceType) {
436436
const context = await navigator.ml.createContext({deviceType});
437437
const limits = context.opSupportLimits();
438-
return limits.preferredInputLayout ?? 'nchw';
438+
439+
const preferredLayout = limits.preferredInputLayout ?? 'nchw';
440+
context.destroy();
441+
return preferredLayout;
439442
}
440443

441444
/**
442445
* Display available models based on device type and data type.
443446
* @param {Object} modelList list of available models.
444447
* @param {Array} modelIds list of model ids.
445-
* @param {String} deviceType 'cpu', 'gpu' or 'npu'.
448+
* @param {String} layout 'nchw' or 'nhwc'.
446449
* @param {String} dataType 'float32', 'float16', or ''.
447450
*/
448-
export function displayAvailableModels(
449-
modelList, modelIds, deviceType, dataType) {
451+
export function displayAvailableModels(modelList, modelIds, layout, dataType) {
450452
let models = [];
451453
if (dataType == '') {
452-
models = models.concat(modelList[deviceType]['float32']);
453-
models = models.concat(modelList[deviceType]['float16']);
454+
models = models.concat(modelList[layout]['float32'] ?? []);
455+
models = models.concat(modelList[layout]['float16'] ?? []);
456+
models = models.concat(modelList[layout]['uint8'] ?? []);
454457
} else {
455-
models = models.concat(modelList[deviceType][dataType]);
458+
models = models.concat(modelList[layout][dataType] ?? []);
456459
}
457460
// Remove duplicate ids.
458461
models = [...new Set(models)];

facial_landmark_detection/ssd_mobilenetv2_face_nchw.js

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
66
export class SsdMobilenetV2FaceNchw {
77
constructor() {
88
this.context_ = null;
9-
this.deviceType_ = null;
109
this.builder_ = null;
1110
this.graph_ = null;
1211
this.inputTensor_ = null;
@@ -117,7 +116,6 @@ ${nameArray[1]}`;
117116

118117
async load(contextOptions) {
119118
this.context_ = await navigator.ml.createContext(contextOptions);
120-
this.deviceType_ = contextOptions.deviceType;
121119
this.builder_ = new MLGraphBuilder(this.context_);
122120
const inputDesc = {
123121
dataType: 'float32',

facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
66
export class SsdMobilenetV2FaceNhwc {
77
constructor() {
88
this.context_ = null;
9-
this.deviceType_ = null;
109
this.builder_ = null;
1110
this.graph_ = null;
1211
this.inputTensor_ = null;
@@ -128,7 +127,6 @@ ${nameArray[1]}`;
128127

129128
async load(contextOptions) {
130129
this.context_ = await navigator.ml.createContext(contextOptions);
131-
this.deviceType_ = contextOptions.deviceType;
132130
this.builder_ = new MLGraphBuilder(this.context_);
133131
const inputDesc = {
134132
dataType: 'float32',

image_classification/index.html

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@
3030
<span>Backend</span>
3131
</div>
3232
<div class="col-md-auto">
33-
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
33+
<div class="btn-group-toggle" data-toggle="buttons" id="deviceTypeBtns">
3434
<label class="btn btn-outline-info custom" name="webnn">
35-
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
35+
<input type="radio" name="deviceType" id="cpu" autocomplete="off">WebNN (CPU)
3636
</label>
3737
<label class="btn btn-outline-info custom" name="webnn">
38-
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off">WebNN (GPU)
38+
<input type="radio" name="deviceType" id="gpu" autocomplete="off">WebNN (GPU)
3939
</label>
4040
<label class="btn btn-outline-info custom" name="webnn">
41-
<input type="radio" name="backend" id="webnn_npu" autocomplete="off">WebNN (NPU)
41+
<input type="radio" name="deviceType" id="npu" autocomplete="off">WebNN (NPU)
4242
</label>
4343
</div>
4444
</div>
@@ -65,13 +65,13 @@
6565
<div class="col-md-auto">
6666
<div class="btn-group-toggle" data-toggle="buttons" id="dataTypeBtns">
6767
<label class="btn btn-outline-info" id="float32Label" active>
68-
<input type="radio" name="layout" id="float32" autocomplete="off" checked>Float32
68+
<input type="radio" name="dataType" id="float32" autocomplete="off" checked>Float32
6969
</label>
7070
<label class="btn btn-outline-info" id="float16Label">
71-
<input type="radio" name="layout" id="float16" autocomplete="off">Float16
71+
<input type="radio" name="dataType" id="float16" autocomplete="off">Float16
7272
</label>
7373
<label class="btn btn-outline-info" id="uint8Label">
74-
<input type="radio" name="layout" id="uint8" autocomplete="off">Uint8
74+
<input type="radio" name="dataType" id="uint8" autocomplete="off">Uint8
7575
</label>
7676
</div>
7777
</div>

image_classification/main.js

Lines changed: 76 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@ const imgElement = document.getElementById('feedElement');
1818
imgElement.src = './images/test.jpg';
1919
const camElement = document.getElementById('feedMediaElement');
2020
let modelName = '';
21-
let modelId = '';
22-
let layout = 'nhwc';
21+
let layout = '';
2322
let dataType = 'float32';
24-
let instanceType = modelName + layout;
23+
let instanceType = '';
2524
let rafReq;
26-
let isFirstTimeLoad = true;
2725
let inputType = 'image';
2826
let netInstance = null;
2927
let labels = null;
@@ -33,48 +31,39 @@ let buildTime = 0;
3331
let computeTime = 0;
3432
let inputOptions;
3533
let deviceType = '';
36-
let lastdeviceType = '';
37-
let backend = '';
38-
let lastBackend = '';
3934
let stopRender = true;
4035
let isRendering = false;
4136
const disabledSelectors = ['#tabs > li', '.btn'];
4237
const modelIds = [
38+
'efficientnet',
4339
'mobilenet',
44-
'squeezenet',
45-
'resnet50v2',
4640
'resnet50v1',
47-
'efficientnet',
41+
'resnet50v2',
42+
'squeezenet',
4843
];
4944
const modelList = {
50-
'cpu': {
45+
'nhwc': {
5146
'float32': [
5247
'mobilenet',
5348
'squeezenet',
5449
'resnet50v2',
5550
],
56-
'uint8': [
57-
'mobilenet',
58-
],
59-
},
60-
'gpu': {
61-
'float32': [
51+
'float16': [
6252
'mobilenet',
6353
'squeezenet',
6454
'resnet50v2',
6555
],
66-
'float16': [
67-
'efficientnet',
56+
'uint8': [
6857
'mobilenet',
69-
'resnet50v1',
7058
],
7159
},
72-
'npu': {
73-
'float16': [
74-
'efficientnet',
60+
'nchw': {
61+
'float32': [
7562
'mobilenet',
76-
'resnet50v1',
63+
'squeezenet',
64+
'resnet50v2',
7765
],
66+
'float16': modelIds,
7867
},
7968
};
8069

@@ -87,58 +76,44 @@ async function fetchLabels(url) {
8776
$(document).ready(async () => {
8877
$('.icdisplay').hide();
8978
if (await utils.isWebNN()) {
90-
$('#webnn_cpu').click();
79+
$('#cpu').click();
9180
} else {
9281
console.log(utils.webNNNotSupportMessage());
9382
ui.addAlert(utils.webNNNotSupportMessageHTML());
9483
}
84+
layout = await utils.getDefaultLayout('cpu');
9585
});
9686

97-
$('#backendBtns .btn').on('change', async (e) => {
87+
$('#deviceTypeBtns .btn').on('change', async (e) => {
9888
if (inputType === 'camera') {
9989
await stopCamRender();
10090
}
101-
[backend, deviceType] = $(e.target).attr('id').split('_');
91+
deviceType = $(e.target).attr('id');
10292
layout = await utils.getDefaultLayout(deviceType);
103-
// Only show the supported models for each deviceType. Now fp16 nchw models
104-
// are only supported on gpu/npu.
105-
if (deviceType == 'gpu') {
106-
ui.handleBtnUI('#float16Label', false);
107-
ui.handleBtnUI('#float32Label', false);
108-
ui.handleBtnUI('#uint8Label', true);
109-
$('#float32').click();
110-
utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
111-
} else if (deviceType == 'npu') {
112-
ui.handleBtnUI('#float16Label', false);
93+
const showUint8 = layout === 'nhwc' ? true : false;
94+
ui.handleBtnUI('#uint8Label', !showUint8);
95+
ui.handleBtnUI('#float16Label', false);
96+
// Only show the supported models for each deviceType.
97+
if (deviceType == 'npu') {
11398
ui.handleBtnUI('#float32Label', true);
114-
ui.handleBtnUI('#uint8Label', true);
11599
$('#float16').click();
116-
utils.displayAvailableModels(modelList, modelIds, deviceType, 'float16');
117100
} else {
118-
ui.handleBtnUI('#float16Label', true);
119101
ui.handleBtnUI('#float32Label', false);
120-
ui.handleBtnUI('#uint8Label', false);
121102
$('#float32').click();
122-
utils.displayAvailableModels(modelList, modelIds, deviceType, 'float32');
123103
}
124104

105+
utils.displayAvailableModels(modelList, modelIds, layout, dataType);
125106
// Uncheck selected model
126-
if (modelId != '') {
127-
$(`#${modelId}`).parent().removeClass('active');
107+
if (modelName != '') {
108+
$(`#${modelName}`).parent().removeClass('active');
128109
}
129110
});
130111

131112
$('#modelBtns .btn').on('change', async (e) => {
132113
if (inputType === 'camera') {
133114
await stopCamRender();
134115
}
135-
modelId = $(e.target).attr('id');
136-
modelName = modelId;
137-
if (dataType == 'float16') {
138-
modelName += 'fp16';
139-
} else if (dataType == 'uint8') {
140-
modelName += 'uint8';
141-
}
116+
modelName = $(e.target).attr('id');
142117

143118
await main();
144119
});
@@ -152,11 +127,15 @@ $('#modelBtns .btn').on('change', async (e) => {
152127
// });
153128

154129
$('#dataTypeBtns .btn').on('change', async (e) => {
130+
if (inputType === 'camera') {
131+
await stopCamRender();
132+
}
133+
155134
dataType = $(e.target).attr('id');
156-
utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
135+
utils.displayAvailableModels(modelList, modelIds, layout, dataType);
157136
// Uncheck selected model
158-
if (modelId != '') {
159-
$(`#${modelId}`).parent().removeClass('active');
137+
if (modelName != '') {
138+
$(`#${modelName}`).parent().removeClass('active');
160139
}
161140
});
162141

@@ -299,50 +278,64 @@ function showPerfResult(medianComputeTime = undefined) {
299278
}
300279
}
301280

302-
function constructNetObject(type) {
303-
const netObject = {
304-
'mobilenetfp16nchw': new MobileNetV2Nchw('float16'),
305-
'resnet50v1fp16nchw': new ResNet50V1FP16Nchw(),
306-
'efficientnetfp16nchw': new EfficientNetFP16Nchw(),
307-
'mobilenetnchw': new MobileNetV2Nchw(),
308-
'mobilenetnhwc': new MobileNetV2Nhwc(),
309-
'mobilenetuint8nhwc': new MobileNetV2Uint8Nhwc(),
310-
'squeezenetnchw': new SqueezeNetNchw(),
311-
'squeezenetnhwc': new SqueezeNetNhwc(),
312-
'resnet50v2nchw': new ResNet50V2Nchw(),
313-
'resnet50v2nhwc': new ResNet50V2Nhwc(),
314-
};
281+
function constructNetObject(modelName, layout, dataType) {
282+
switch (modelName) {
283+
case 'efficientnet':
284+
if (layout == 'nchw' && dataType == 'float16') {
285+
return new EfficientNetFP16Nchw();
286+
}
287+
break;
288+
case 'mobilenet':
289+
if (layout == 'nhwc' && dataType == 'uint8') {
290+
return new MobileNetV2Uint8Nhwc();
291+
} else if (dataType != 'uint8') {
292+
return layout == 'nhwc' ?
293+
new MobileNetV2Nhwc(dataType) : new MobileNetV2Nchw(dataType);
294+
}
295+
break;
296+
case 'resnet50v1':
297+
if (layout == 'nchw' && dataType == 'float16') {
298+
return new ResNet50V1FP16Nchw();
299+
}
300+
break;
301+
case 'resnet50v2':
302+
if (dataType != 'uint8') {
303+
return layout == 'nhwc' ?
304+
new ResNet50V2Nhwc(dataType) : new ResNet50V2Nchw(dataType);
305+
}
306+
break;
307+
case 'squeezenet':
308+
if (dataType != 'uint8') {
309+
return layout == 'nhwc' ?
310+
new SqueezeNetNhwc() : new SqueezeNetNchw();
311+
}
312+
break;
313+
}
315314

316-
return netObject[type];
315+
throw new Error(`Unknown model, name: ${modelName}, layout: ${layout}, ` +
316+
`dataType: ${dataType}`);
317317
}
318318

319319
async function main() {
320320
try {
321321
if (modelName === '') return;
322322
ui.handleClick(disabledSelectors, true);
323-
if (isFirstTimeLoad) $('#hint').hide();
323+
if (instanceType == '') $('#hint').hide();
324324
let start;
325325
const [numRuns, powerPreference] = utils.getUrlParams();
326326

327327
// Only do load() and build() when model first time loads,
328-
// there's new model choosed, backend changed or device changed
329-
if (isFirstTimeLoad || instanceType !== modelName + layout ||
330-
lastdeviceType != deviceType || lastBackend != backend) {
331-
if (lastdeviceType != deviceType || lastBackend != backend) {
332-
// Set backend and device
333-
lastdeviceType = lastdeviceType != deviceType ?
334-
deviceType : lastdeviceType;
335-
lastBackend = lastBackend != backend ? backend : lastBackend;
336-
}
337-
instanceType = modelName + layout;
338-
netInstance = constructNetObject(instanceType);
328+
// there's new model choosed
329+
if (instanceType !== modelName + dataType + layout + deviceType) {
330+
instanceType = modelName + dataType + layout + deviceType;
331+
netInstance = constructNetObject(modelName, layout, dataType);
339332
inputOptions = netInstance.inputOptions;
340333
labels = await fetchLabels(inputOptions.labelUrl);
341-
isFirstTimeLoad = false;
342-
console.log(`- Model name: ${modelName}, Model layout: ${layout} -`);
334+
console.log(
335+
`- Model: ${modelName} - ${layout} - ${dataType} - ${deviceType}`);
343336
// UI shows model loading progress
344337
await ui.showProgressComponent('current', 'pending', 'pending');
345-
console.log('- Loading weights... ');
338+
console.log('- Loading weights...');
346339
const contextOptions = {deviceType};
347340
if (powerPreference) {
348341
contextOptions['powerPreference'] = powerPreference;

image_classification/mobilenet_nchw.js

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
66
export class MobileNetV2Nchw {
77
constructor(dataType = 'float32') {
88
this.context_ = null;
9-
this.deviceType_ = null;
109
this.builder_ = null;
1110
this.graph_ = null;
1211
this.inputTensor_ = null;
@@ -89,7 +88,6 @@ export class MobileNetV2Nchw {
8988

9089
async load(contextOptions) {
9190
this.context_ = await navigator.ml.createContext(contextOptions);
92-
this.deviceType_ = contextOptions.deviceType;
9391
this.builder_ = new MLGraphBuilder(this.context_);
9492
const inputDesc = {
9593
dataType: 'float32',
@@ -107,6 +105,7 @@ export class MobileNetV2Nchw {
107105
usage: MLTensorUsage.READ,
108106
readable: true,
109107
});
108+
110109
if (this.dataType_ === 'float16') {
111110
data = this.builder_.cast(data, 'float16');
112111
}

0 commit comments

Comments
 (0)