Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions common/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -435,24 +435,27 @@ export function permuteData(array, dims, axes) {
export async function getDefaultLayout(deviceType) {
const context = await navigator.ml.createContext({deviceType});
const limits = context.opSupportLimits();
return limits.preferredInputLayout ?? 'nchw';

const preferredLayout = limits.preferredInputLayout ?? 'nchw';
context.destroy();
return preferredLayout;
}

/**
* Display available models based on device type and data type.
* @param {Object} modelList list of available models.
* @param {Array} modelIds list of model ids.
* @param {String} deviceType 'cpu', 'gpu' or 'npu'.
* @param {String} layout 'nchw' or 'nhwc'.
* @param {String} dataType 'float32', 'float16', or ''.
*/
export function displayAvailableModels(
modelList, modelIds, deviceType, dataType) {
export function displayAvailableModels(modelList, modelIds, layout, dataType) {
let models = [];
if (dataType == '') {
models = models.concat(modelList[deviceType]['float32']);
models = models.concat(modelList[deviceType]['float16']);
models = models.concat(modelList[layout]['float32'] ?? []);
models = models.concat(modelList[layout]['float16'] ?? []);
models = models.concat(modelList[layout]['uint8'] ?? []);
} else {
models = models.concat(modelList[deviceType][dataType]);
models = models.concat(modelList[layout][dataType] ?? []);
}
// Remove duplicate ids.
models = [...new Set(models)];
Expand Down
2 changes: 0 additions & 2 deletions facial_landmark_detection/ssd_mobilenetv2_face_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
export class SsdMobilenetV2FaceNchw {
constructor() {
this.context_ = null;
this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
Expand Down Expand Up @@ -117,7 +116,6 @@ ${nameArray[1]}`;

async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
const inputDesc = {
dataType: 'float32',
Expand Down
2 changes: 0 additions & 2 deletions facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
export class SsdMobilenetV2FaceNhwc {
constructor() {
this.context_ = null;
this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
Expand Down Expand Up @@ -128,7 +127,6 @@ ${nameArray[1]}`;

async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
const inputDesc = {
dataType: 'float32',
Expand Down
14 changes: 7 additions & 7 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@
<span>Backend</span>
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<div class="btn-group-toggle" data-toggle="buttons" id="deviceTypeBtns">
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
<input type="radio" name="deviceType" id="cpu" autocomplete="off">WebNN (CPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off">WebNN (GPU)
<input type="radio" name="deviceType" id="gpu" autocomplete="off">WebNN (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_npu" autocomplete="off">WebNN (NPU)
<input type="radio" name="deviceType" id="npu" autocomplete="off">WebNN (NPU)
</label>
</div>
</div>
Expand All @@ -65,13 +65,13 @@
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="dataTypeBtns">
<label class="btn btn-outline-info" id="float32Label" active>
<input type="radio" name="layout" id="float32" autocomplete="off" checked>Float32
<input type="radio" name="dataType" id="float32" autocomplete="off" checked>Float32
</label>
<label class="btn btn-outline-info" id="float16Label">
<input type="radio" name="layout" id="float16" autocomplete="off">Float16
<input type="radio" name="dataType" id="float16" autocomplete="off">Float16
</label>
<label class="btn btn-outline-info" id="uint8Label">
<input type="radio" name="layout" id="uint8" autocomplete="off">Uint8
<input type="radio" name="dataType" id="uint8" autocomplete="off">Uint8
</label>
</div>
</div>
Expand Down
159 changes: 76 additions & 83 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ const imgElement = document.getElementById('feedElement');
imgElement.src = './images/test.jpg';
const camElement = document.getElementById('feedMediaElement');
let modelName = '';
let modelId = '';
let layout = 'nhwc';
let layout = '';
let dataType = 'float32';
let instanceType = modelName + layout;
let instanceType = '';
let rafReq;
let isFirstTimeLoad = true;
let inputType = 'image';
let netInstance = null;
let labels = null;
Expand All @@ -33,48 +31,39 @@ let buildTime = 0;
let computeTime = 0;
let inputOptions;
let deviceType = '';
let lastdeviceType = '';
let backend = '';
let lastBackend = '';
let stopRender = true;
let isRendering = false;
const disabledSelectors = ['#tabs > li', '.btn'];
const modelIds = [
'efficientnet',
'mobilenet',
'squeezenet',
'resnet50v2',
'resnet50v1',
'efficientnet',
'resnet50v2',
'squeezenet',
];
const modelList = {
'cpu': {
'nhwc': {
'float32': [
'mobilenet',
'squeezenet',
'resnet50v2',
],
'uint8': [
'mobilenet',
],
},
'gpu': {
'float32': [
'float16': [
'mobilenet',
'squeezenet',
'resnet50v2',
],
'float16': [
'efficientnet',
'uint8': [
'mobilenet',
'resnet50v1',
],
},
'npu': {
'float16': [
'efficientnet',
'nchw': {
'float32': [
'mobilenet',
'resnet50v1',
'squeezenet',
'resnet50v2',
],
'float16': modelIds,
},
};

Expand All @@ -87,58 +76,44 @@ async function fetchLabels(url) {
$(document).ready(async () => {
$('.icdisplay').hide();
if (await utils.isWebNN()) {
$('#webnn_cpu').click();
$('#cpu').click();
} else {
console.log(utils.webNNNotSupportMessage());
ui.addAlert(utils.webNNNotSupportMessageHTML());
}
layout = await utils.getDefaultLayout('cpu');
});

$('#backendBtns .btn').on('change', async (e) => {
$('#deviceTypeBtns .btn').on('change', async (e) => {
if (inputType === 'camera') {
await stopCamRender();
}
[backend, deviceType] = $(e.target).attr('id').split('_');
deviceType = $(e.target).attr('id');
layout = await utils.getDefaultLayout(deviceType);
// Only show the supported models for each deviceType. Now fp16 nchw models
// are only supported on gpu/npu.
if (deviceType == 'gpu') {
ui.handleBtnUI('#float16Label', false);
ui.handleBtnUI('#float32Label', false);
ui.handleBtnUI('#uint8Label', true);
$('#float32').click();
utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
} else if (deviceType == 'npu') {
ui.handleBtnUI('#float16Label', false);
const showUint8 = layout === 'nhwc' ? true : false;
ui.handleBtnUI('#uint8Label', !showUint8);
ui.handleBtnUI('#float16Label', false);
// Only show the supported models for each deviceType.
if (deviceType == 'npu') {
ui.handleBtnUI('#float32Label', true);
ui.handleBtnUI('#uint8Label', true);
$('#float16').click();
utils.displayAvailableModels(modelList, modelIds, deviceType, 'float16');
} else {
ui.handleBtnUI('#float16Label', true);
ui.handleBtnUI('#float32Label', false);
ui.handleBtnUI('#uint8Label', false);
$('#float32').click();
utils.displayAvailableModels(modelList, modelIds, deviceType, 'float32');
}

utils.displayAvailableModels(modelList, modelIds, layout, dataType);
// Uncheck selected model
if (modelId != '') {
$(`#${modelId}`).parent().removeClass('active');
if (modelName != '') {
$(`#${modelName}`).parent().removeClass('active');
}
});

$('#modelBtns .btn').on('change', async (e) => {
if (inputType === 'camera') {
await stopCamRender();
}
modelId = $(e.target).attr('id');
modelName = modelId;
if (dataType == 'float16') {
modelName += 'fp16';
} else if (dataType == 'uint8') {
modelName += 'uint8';
}
modelName = $(e.target).attr('id');

await main();
});
Expand All @@ -152,11 +127,15 @@ $('#modelBtns .btn').on('change', async (e) => {
// });

$('#dataTypeBtns .btn').on('change', async (e) => {
if (inputType === 'camera') {
await stopCamRender();
}

dataType = $(e.target).attr('id');
utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
utils.displayAvailableModels(modelList, modelIds, layout, dataType);
// Uncheck selected model
if (modelId != '') {
$(`#${modelId}`).parent().removeClass('active');
if (modelName != '') {
$(`#${modelName}`).parent().removeClass('active');
}
});

Expand Down Expand Up @@ -299,50 +278,64 @@ function showPerfResult(medianComputeTime = undefined) {
}
}

function constructNetObject(type) {
const netObject = {
'mobilenetfp16nchw': new MobileNetV2Nchw('float16'),
'resnet50v1fp16nchw': new ResNet50V1FP16Nchw(),
'efficientnetfp16nchw': new EfficientNetFP16Nchw(),
'mobilenetnchw': new MobileNetV2Nchw(),
'mobilenetnhwc': new MobileNetV2Nhwc(),
'mobilenetuint8nhwc': new MobileNetV2Uint8Nhwc(),
'squeezenetnchw': new SqueezeNetNchw(),
'squeezenetnhwc': new SqueezeNetNhwc(),
'resnet50v2nchw': new ResNet50V2Nchw(),
'resnet50v2nhwc': new ResNet50V2Nhwc(),
};
function constructNetObject(modelName, layout, dataType) {
switch (modelName) {
case 'efficientnet':
if (layout == 'nchw' && dataType == 'float16') {
return new EfficientNetFP16Nchw();
}
break;
case 'mobilenet':
if (layout == 'nhwc' && dataType == 'uint8') {
return new MobileNetV2Uint8Nhwc();
} else if (dataType != 'uint8') {
return layout == 'nhwc' ?
new MobileNetV2Nhwc(dataType) : new MobileNetV2Nchw(dataType);
}
break;
case 'resnet50v1':
if (layout == 'nchw' && dataType == 'float16') {
return new ResNet50V1FP16Nchw();
}
break;
case 'resnet50v2':
if (dataType != 'uint8') {
return layout == 'nhwc' ?
new ResNet50V2Nhwc(dataType) : new ResNet50V2Nchw(dataType);
}
break;
case 'squeezenet':
if (dataType != 'uint8') {
return layout == 'nhwc' ?
new SqueezeNetNhwc() : new SqueezeNetNchw();
}
break;
}

return netObject[type];
throw new Error(`Unknown model, name: ${modelName}, layout: ${layout}, ` +
`dataType: ${dataType}`);
}

async function main() {
try {
if (modelName === '') return;
ui.handleClick(disabledSelectors, true);
if (isFirstTimeLoad) $('#hint').hide();
if (instanceType == '') $('#hint').hide();
let start;
const [numRuns, powerPreference] = utils.getUrlParams();

// Only do load() and build() when model first time loads,
// there's new model choosed, backend changed or device changed
if (isFirstTimeLoad || instanceType !== modelName + layout ||
lastdeviceType != deviceType || lastBackend != backend) {
if (lastdeviceType != deviceType || lastBackend != backend) {
// Set backend and device
lastdeviceType = lastdeviceType != deviceType ?
deviceType : lastdeviceType;
lastBackend = lastBackend != backend ? backend : lastBackend;
}
instanceType = modelName + layout;
netInstance = constructNetObject(instanceType);
// there's new model choosed
if (instanceType !== modelName + dataType + layout + deviceType) {
instanceType = modelName + dataType + layout + deviceType;
netInstance = constructNetObject(modelName, layout, dataType);
inputOptions = netInstance.inputOptions;
labels = await fetchLabels(inputOptions.labelUrl);
isFirstTimeLoad = false;
console.log(`- Model name: ${modelName}, Model layout: ${layout} -`);
console.log(
`- Model: ${modelName} - ${layout} - ${dataType} - ${deviceType}`);
// UI shows model loading progress
await ui.showProgressComponent('current', 'pending', 'pending');
console.log('- Loading weights... ');
console.log('- Loading weights...');
const contextOptions = {deviceType};
if (powerPreference) {
contextOptions['powerPreference'] = powerPreference;
Expand Down
3 changes: 1 addition & 2 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
export class MobileNetV2Nchw {
constructor(dataType = 'float32') {
this.context_ = null;
this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
Expand Down Expand Up @@ -89,7 +88,6 @@ export class MobileNetV2Nchw {

async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
const inputDesc = {
dataType: 'float32',
Expand All @@ -107,6 +105,7 @@ export class MobileNetV2Nchw {
usage: MLTensorUsage.READ,
readable: true,
});

if (this.dataType_ === 'float16') {
data = this.builder_.cast(data, 'float16');
}
Expand Down
Loading