diff --git a/common/utils.js b/common/utils.js
index 46e097c2..9d29c3a8 100644
--- a/common/utils.js
+++ b/common/utils.js
@@ -435,24 +435,27 @@ export function permuteData(array, dims, axes) {
export async function getDefaultLayout(deviceType) {
const context = await navigator.ml.createContext({deviceType});
const limits = context.opSupportLimits();
- return limits.preferredInputLayout ?? 'nchw';
+
+ const preferredLayout = limits.preferredInputLayout ?? 'nchw';
+ context.destroy();
+ return preferredLayout;
}
/**
* Display available models based on device type and data type.
* @param {Object} modelList list of available models.
* @param {Array} modelIds list of model ids.
- * @param {String} deviceType 'cpu', 'gpu' or 'npu'.
+ * @param {String} layout 'nchw' or 'nhwc'.
* @param {String} dataType 'float32', 'float16', or ''.
*/
-export function displayAvailableModels(
- modelList, modelIds, deviceType, dataType) {
+export function displayAvailableModels(modelList, modelIds, layout, dataType) {
let models = [];
if (dataType == '') {
- models = models.concat(modelList[deviceType]['float32']);
- models = models.concat(modelList[deviceType]['float16']);
+ models = models.concat(modelList[layout]['float32'] ?? []);
+ models = models.concat(modelList[layout]['float16'] ?? []);
+ models = models.concat(modelList[layout]['uint8'] ?? []);
} else {
- models = models.concat(modelList[deviceType][dataType]);
+ models = models.concat(modelList[layout][dataType] ?? []);
}
// Remove duplicate ids.
models = [...new Set(models)];
diff --git a/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js b/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js
index 76ad51f1..2f70d1bb 100644
--- a/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js
+++ b/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js
@@ -6,7 +6,6 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
export class SsdMobilenetV2FaceNchw {
constructor() {
this.context_ = null;
- this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
@@ -117,7 +116,6 @@ ${nameArray[1]}`;
async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
- this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
const inputDesc = {
dataType: 'float32',
diff --git a/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js b/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js
index dadf5a95..58305b1b 100644
--- a/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js
+++ b/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js
@@ -6,7 +6,6 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
export class SsdMobilenetV2FaceNhwc {
constructor() {
this.context_ = null;
- this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
@@ -128,7 +127,6 @@ ${nameArray[1]}`;
async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
- this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
const inputDesc = {
dataType: 'float32',
diff --git a/image_classification/index.html b/image_classification/index.html
index 496a318f..bb3e9cf7 100644
--- a/image_classification/index.html
+++ b/image_classification/index.html
@@ -30,15 +30,15 @@
Backend
-
@@ -65,13 +65,13 @@
diff --git a/image_classification/main.js b/image_classification/main.js
index 29767894..d6ce9d07 100644
--- a/image_classification/main.js
+++ b/image_classification/main.js
@@ -18,12 +18,10 @@ const imgElement = document.getElementById('feedElement');
imgElement.src = './images/test.jpg';
const camElement = document.getElementById('feedMediaElement');
let modelName = '';
-let modelId = '';
-let layout = 'nhwc';
+let layout = '';
let dataType = 'float32';
-let instanceType = modelName + layout;
+let instanceType = '';
let rafReq;
-let isFirstTimeLoad = true;
let inputType = 'image';
let netInstance = null;
let labels = null;
@@ -33,48 +31,39 @@ let buildTime = 0;
let computeTime = 0;
let inputOptions;
let deviceType = '';
-let lastdeviceType = '';
-let backend = '';
-let lastBackend = '';
let stopRender = true;
let isRendering = false;
const disabledSelectors = ['#tabs > li', '.btn'];
const modelIds = [
+ 'efficientnet',
'mobilenet',
- 'squeezenet',
- 'resnet50v2',
'resnet50v1',
- 'efficientnet',
+ 'resnet50v2',
+ 'squeezenet',
];
const modelList = {
- 'cpu': {
+ 'nhwc': {
'float32': [
'mobilenet',
'squeezenet',
'resnet50v2',
],
- 'uint8': [
- 'mobilenet',
- ],
- },
- 'gpu': {
- 'float32': [
+ 'float16': [
'mobilenet',
'squeezenet',
'resnet50v2',
],
- 'float16': [
- 'efficientnet',
+ 'uint8': [
'mobilenet',
- 'resnet50v1',
],
},
- 'npu': {
- 'float16': [
- 'efficientnet',
+ 'nchw': {
+ 'float32': [
'mobilenet',
- 'resnet50v1',
+ 'squeezenet',
+ 'resnet50v2',
],
+ 'float16': modelIds,
},
};
@@ -87,44 +76,36 @@ async function fetchLabels(url) {
$(document).ready(async () => {
$('.icdisplay').hide();
if (await utils.isWebNN()) {
- $('#webnn_cpu').click();
+ $('#cpu').click();
} else {
console.log(utils.webNNNotSupportMessage());
ui.addAlert(utils.webNNNotSupportMessageHTML());
}
+ layout = await utils.getDefaultLayout('cpu');
});
-$('#backendBtns .btn').on('change', async (e) => {
+$('#deviceTypeBtns .btn').on('change', async (e) => {
if (inputType === 'camera') {
await stopCamRender();
}
- [backend, deviceType] = $(e.target).attr('id').split('_');
+ deviceType = $(e.target).attr('id');
layout = await utils.getDefaultLayout(deviceType);
- // Only show the supported models for each deviceType. Now fp16 nchw models
- // are only supported on gpu/npu.
- if (deviceType == 'gpu') {
- ui.handleBtnUI('#float16Label', false);
- ui.handleBtnUI('#float32Label', false);
- ui.handleBtnUI('#uint8Label', true);
- $('#float32').click();
- utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
- } else if (deviceType == 'npu') {
- ui.handleBtnUI('#float16Label', false);
+ const showUint8 = layout === 'nhwc' ? true : false;
+ ui.handleBtnUI('#uint8Label', !showUint8);
+ ui.handleBtnUI('#float16Label', false);
+ // Only show the supported models for each deviceType.
+ if (deviceType == 'npu') {
ui.handleBtnUI('#float32Label', true);
- ui.handleBtnUI('#uint8Label', true);
$('#float16').click();
- utils.displayAvailableModels(modelList, modelIds, deviceType, 'float16');
} else {
- ui.handleBtnUI('#float16Label', true);
ui.handleBtnUI('#float32Label', false);
- ui.handleBtnUI('#uint8Label', false);
$('#float32').click();
- utils.displayAvailableModels(modelList, modelIds, deviceType, 'float32');
}
+ utils.displayAvailableModels(modelList, modelIds, layout, dataType);
// Uncheck selected model
- if (modelId != '') {
- $(`#${modelId}`).parent().removeClass('active');
+ if (modelName != '') {
+ $(`#${modelName}`).parent().removeClass('active');
}
});
@@ -132,13 +113,7 @@ $('#modelBtns .btn').on('change', async (e) => {
if (inputType === 'camera') {
await stopCamRender();
}
- modelId = $(e.target).attr('id');
- modelName = modelId;
- if (dataType == 'float16') {
- modelName += 'fp16';
- } else if (dataType == 'uint8') {
- modelName += 'uint8';
- }
+ modelName = $(e.target).attr('id');
await main();
});
@@ -152,11 +127,15 @@ $('#modelBtns .btn').on('change', async (e) => {
// });
$('#dataTypeBtns .btn').on('change', async (e) => {
+ if (inputType === 'camera') {
+ await stopCamRender();
+ }
+
dataType = $(e.target).attr('id');
- utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
+ utils.displayAvailableModels(modelList, modelIds, layout, dataType);
// Uncheck selected model
- if (modelId != '') {
- $(`#${modelId}`).parent().removeClass('active');
+ if (modelName != '') {
+ $(`#${modelName}`).parent().removeClass('active');
}
});
@@ -299,50 +278,64 @@ function showPerfResult(medianComputeTime = undefined) {
}
}
-function constructNetObject(type) {
- const netObject = {
- 'mobilenetfp16nchw': new MobileNetV2Nchw('float16'),
- 'resnet50v1fp16nchw': new ResNet50V1FP16Nchw(),
- 'efficientnetfp16nchw': new EfficientNetFP16Nchw(),
- 'mobilenetnchw': new MobileNetV2Nchw(),
- 'mobilenetnhwc': new MobileNetV2Nhwc(),
- 'mobilenetuint8nhwc': new MobileNetV2Uint8Nhwc(),
- 'squeezenetnchw': new SqueezeNetNchw(),
- 'squeezenetnhwc': new SqueezeNetNhwc(),
- 'resnet50v2nchw': new ResNet50V2Nchw(),
- 'resnet50v2nhwc': new ResNet50V2Nhwc(),
- };
+function constructNetObject(modelName, layout, dataType) {
+ switch (modelName) {
+ case 'efficientnet':
+ if (layout == 'nchw' && dataType == 'float16') {
+ return new EfficientNetFP16Nchw();
+ }
+ break;
+ case 'mobilenet':
+ if (layout == 'nhwc' && dataType == 'uint8') {
+ return new MobileNetV2Uint8Nhwc();
+ } else if (dataType != 'uint8') {
+ return layout == 'nhwc' ?
+ new MobileNetV2Nhwc(dataType) : new MobileNetV2Nchw(dataType);
+ }
+ break;
+ case 'resnet50v1':
+ if (layout == 'nchw' && dataType == 'float16') {
+ return new ResNet50V1FP16Nchw();
+ }
+ break;
+ case 'resnet50v2':
+ if (dataType != 'uint8') {
+ return layout == 'nhwc' ?
+ new ResNet50V2Nhwc(dataType) : new ResNet50V2Nchw(dataType);
+ }
+ break;
+ case 'squeezenet':
+ if (dataType != 'uint8') {
+ return layout == 'nhwc' ?
+ new SqueezeNetNhwc() : new SqueezeNetNchw();
+ }
+ break;
+ }
- return netObject[type];
+ throw new Error(`Unknown model, name: ${modelName}, layout: ${layout}, ` +
+ `dataType: ${dataType}`);
}
async function main() {
try {
if (modelName === '') return;
ui.handleClick(disabledSelectors, true);
- if (isFirstTimeLoad) $('#hint').hide();
+ if (instanceType == '') $('#hint').hide();
let start;
const [numRuns, powerPreference] = utils.getUrlParams();
// Only do load() and build() when model first time loads,
- // there's new model choosed, backend changed or device changed
- if (isFirstTimeLoad || instanceType !== modelName + layout ||
- lastdeviceType != deviceType || lastBackend != backend) {
- if (lastdeviceType != deviceType || lastBackend != backend) {
- // Set backend and device
- lastdeviceType = lastdeviceType != deviceType ?
- deviceType : lastdeviceType;
- lastBackend = lastBackend != backend ? backend : lastBackend;
- }
- instanceType = modelName + layout;
- netInstance = constructNetObject(instanceType);
+ // there's new model choosed
+ if (instanceType !== modelName + dataType + layout + deviceType) {
+ instanceType = modelName + dataType + layout + deviceType;
+ netInstance = constructNetObject(modelName, layout, dataType);
inputOptions = netInstance.inputOptions;
labels = await fetchLabels(inputOptions.labelUrl);
- isFirstTimeLoad = false;
- console.log(`- Model name: ${modelName}, Model layout: ${layout} -`);
+ console.log(
+ `- Model: ${modelName} - ${layout} - ${dataType} - ${deviceType}`);
// UI shows model loading progress
await ui.showProgressComponent('current', 'pending', 'pending');
- console.log('- Loading weights... ');
+ console.log('- Loading weights...');
const contextOptions = {deviceType};
if (powerPreference) {
contextOptions['powerPreference'] = powerPreference;
diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js
index b14d7a0b..a5f91b06 100644
--- a/image_classification/mobilenet_nchw.js
+++ b/image_classification/mobilenet_nchw.js
@@ -6,7 +6,6 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
export class MobileNetV2Nchw {
constructor(dataType = 'float32') {
this.context_ = null;
- this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
@@ -89,7 +88,6 @@ export class MobileNetV2Nchw {
async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
- this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
const inputDesc = {
dataType: 'float32',
@@ -107,6 +105,7 @@ export class MobileNetV2Nchw {
usage: MLTensorUsage.READ,
readable: true,
});
+
if (this.dataType_ === 'float16') {
data = this.builder_.cast(data, 'float16');
}
diff --git a/image_classification/mobilenet_nhwc.js b/image_classification/mobilenet_nhwc.js
index 733e225a..a0a4a353 100644
--- a/image_classification/mobilenet_nhwc.js
+++ b/image_classification/mobilenet_nhwc.js
@@ -6,13 +6,13 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
// MobileNet V2 model with 'nhwc' input layout
export class MobileNetV2Nhwc {
- constructor() {
+ constructor(dataType = 'float32') {
this.context_ = null;
- this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
this.outputTensor_ = null;
+ this.targetDataType_ = dataType;
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/mobilenetv2_nhwc/weights/';
this.inputOptions = {
@@ -27,9 +27,10 @@ export class MobileNetV2Nhwc {
async buildConv_(input, weightsSubName, biasSubName, relu6, options) {
const weightsName = this.weightsUrl_ + 'Const_' + weightsSubName + '.npy';
- const weights = await buildConstantByNpy(this.builder_, weightsName);
+ const weights = await buildConstantByNpy(
+ this.builder_, weightsName, this.targetDataType_);
const biasName = this.weightsUrl_ + 'MobilenetV2_' + biasSubName + '_bias.npy';
- const bias = buildConstantByNpy(this.builder_, biasName);
+ const bias = buildConstantByNpy(this.builder_, biasName, this.targetDataType_);
options.inputLayout = 'nhwc';
options.bias = await bias;
// WebNN spec drops autoPad support, compute the explicit padding instead.
@@ -87,7 +88,6 @@ export class MobileNetV2Nhwc {
async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
- this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
const strides = [2, 2];
const autoPad = 'same-upper';
@@ -97,7 +97,7 @@ export class MobileNetV2Nhwc {
dimensions: this.inputOptions.inputShape,
shape: this.inputOptions.inputShape,
};
- const input = this.builder_.input('input', inputDesc);
+ let input = this.builder_.input('input', inputDesc);
inputDesc.usage = MLTensorUsage.WRITE;
inputDesc.writable = true;
this.inputTensor_ = await this.context_.createTensor(inputDesc);
@@ -108,6 +108,10 @@ export class MobileNetV2Nhwc {
usage: MLTensorUsage.READ,
readable: true,
});
+
+ if (this.targetDataType_ === 'float16') {
+ input = this.builder_.cast(input, 'float16');
+ }
const conv0 = this.buildConv_(
input, '90', 'Conv_Conv2D', true, {strides, autoPad, filterLayout});
const conv1 = this.buildConv_(
@@ -153,7 +157,12 @@ export class MobileNetV2Nhwc {
const conv4 = this.buildConv_(
averagePool2d, '222', 'Logits_Conv2d_1c_1x1_Conv2D', false, {autoPad, filterLayout});
const reshape = this.builder_.reshape(await conv4, [1, 1001]);
- return await this.builder_.softmax(reshape, 1);
+ const softmax = await this.builder_.softmax(reshape, 1);
+
+ if (this.targetDataType_ === 'float16') {
+ return this.builder_.cast(softmax, 'float32');
+ }
+ return softmax;
}
async build(outputOperand) {
diff --git a/image_classification/mobilenet_uint8_nhwc.js b/image_classification/mobilenet_uint8_nhwc.js
index 4f39c521..d4c2e90c 100644
--- a/image_classification/mobilenet_uint8_nhwc.js
+++ b/image_classification/mobilenet_uint8_nhwc.js
@@ -31,15 +31,15 @@ export class MobileNetV2Uint8Nhwc {
quantizateParams.shape.push(...Array(missingDims).fill(1));
}
- const scale = this.builder_.constant( {dataType: 'float32', shape: quantizateParams.shape},
+ const scale = this.builder_.constant( {dataType: 'float32', shape: quantizateParams.shape},
new Float32Array(quantizateParams.scale));
let zeroPoint;
if (dataType === 'uint8') {
zeroPoint = this.builder_.constant( {dataType: 'uint8', shape: quantizateParams.shape},
- new Uint8Array(quantizateParams.zero_point));
+ new Uint8Array(quantizateParams.zero_point));
} else if (dataType === 'int32') {
zeroPoint = this.builder_.constant( {dataType: 'int32', shape: quantizateParams.shape},
- new Int32Array(quantizateParams.zero_point));
+ new Int32Array(quantizateParams.zero_point));
} else {
throw new Error(`Data type ${dataType} is not supported.`);
}
@@ -53,10 +53,10 @@ export class MobileNetV2Uint8Nhwc {
quantizateParams.shape.push(...Array(missingDims).fill(1));
}
- const scale = this.builder_.constant( {dataType: 'float32', shape: quantizateParams.shape},
+ const scale = this.builder_.constant( {dataType: 'float32', shape: quantizateParams.shape},
new Float32Array(quantizateParams.scale));
const zeroPoint = this.builder_.constant( {dataType: 'uint8', shape: quantizateParams.shape},
- new Uint8Array(quantizateParams.zero_point));
+ new Uint8Array(quantizateParams.zero_point));
return this.builder_.quantizeLinear(input, scale, zeroPoint);
}
@@ -141,57 +141,57 @@ export class MobileNetV2Uint8Nhwc {
readable: true,
});
const conv0 = await this.buildConv_(
- input, '55',
+ input, '55',
{scale: [0.034375786781311035], zero_point: [159], shape: []},
{scale: [0.0002706754894461483], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
true, {strides, autoPad, filterLayout});
const conv1 = await this.buildConv_(
- conv0, '57',
+ conv0, '57',
{scale: [0.5174643397331238], zero_point: [115], shape: []},
- {scale: [0.012175869196653366], zero_point: [0], shape: []},
+ {scale: [0.012175869196653366], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
true, {autoPad, groups: 32, filterLayout: 'ihwo'});
const conv2 = await this.buildConv_(
- conv1, '59',
+ conv1, '59',
{scale: [0.06309328228235245], zero_point: [90], shape: []},
- {scale: [0.0014845768455415964], zero_point: [0], shape: []},
+ {scale: [0.0014845768455415964], zero_point: [0], shape: []},
{scale: [0.3935633599758148], zero_point: [129], shape: []},
false, {autoPad, filterLayout});
const bottleneck0 = await this.buildLinearBottleneck_(
- conv2, ['61', '63', '65'],
+ conv2, ['61', '63', '65'],
[
{scale: [0.008153429254889488], zero_point: [85], shape: []},
{scale: [0.01082384679466486], zero_point: [118], shape: []},
- {scale: [0.03367125615477562], zero_point: [152], shape: []}
+ {scale: [0.03367125615477562], zero_point: [152], shape: []},
],
[
{scale: [0.003208891022950411], zero_point: [0], shape: []},
{scale: [0.00025468372041359544], zero_point: [0], shape: []},
- {scale: [0.0007922803633846343], zero_point: [0], shape: []}
+ {scale: [0.0007922803633846343], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.02352987229824066 ], zero_point: [0], shape: []},
- {scale: [0.32650214433670044], zero_point: [117], shape: []}
+ {scale: [0.02352987229824066], zero_point: [0], shape: []},
+ {scale: [0.32650214433670044], zero_point: [117], shape: []},
],
{}, {strides, groups: 96}, false);
const bottleneck1 = await this.buildLinearBottleneck_(
- bottleneck0, ['67', '69', '71'],
+ bottleneck0, ['67', '69', '71'],
[
{scale: [0.003573313821107149], zero_point: [102], shape: []},
{scale: [0.14301884174346924], zero_point: [166], shape: []},
- {scale: [0.0644076019525528], zero_point: [122], shape: []}
+ {scale: [0.0644076019525528], zero_point: [122], shape: []},
],
[
{scale: [0.0011666945647448301], zero_point: [0], shape: []},
{scale: [0.003365214914083481], zero_point: [0], shape: []},
- {scale: [0.0015155026922002435], zero_point: [0], shape: []}
+ {scale: [0.0015155026922002435], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.6063596606254578], zero_point: [111], shape: []}
+ {scale: [0.6063596606254578], zero_point: [111], shape: []},
],
{scale: [0.609800398349762], zero_point: [117], shape: []},
{groups: 144});
@@ -200,92 +200,92 @@ export class MobileNetV2Uint8Nhwc {
[
{scale: [0.002144400728866458], zero_point: [129], shape: []},
{scale: [0.026209063827991486], zero_point: [141], shape: []},
- {scale: [0.030137652531266212], zero_point: [160], shape: []}
+ {scale: [0.030137652531266212], zero_point: [160], shape: []},
],
[
{scale: [0.0013076564064249396], zero_point: [0], shape: []},
{scale: [0.0006166959065012634], zero_point: [0], shape: []},
- {scale: [0.0007091350853443146], zero_point: [0], shape: []}
+ {scale: [0.0007091350853443146], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.28959763050079346], zero_point: [132], shape: []}
+ {scale: [0.28959763050079346], zero_point: [132], shape: []},
],
{}, {strides, groups: 144}, false);
const bottleneck3 = await this.buildLinearBottleneck_(
- bottleneck2, ['80', '82', '84'],
+ bottleneck2, ['80', '82', '84'],
[
{scale: [0.0018756906501948833], zero_point: [129], shape: []},
{scale: [0.04708000645041466], zero_point: [107], shape: []},
- {scale: [0.021888649091124535], zero_point: [144], shape: []}
+ {scale: [0.021888649091124535], zero_point: [144], shape: []},
],
[
{scale: [0.0005431955796666443], zero_point: [0], shape: []},
{scale: [0.0011077865492552519], zero_point: [0], shape: []},
- {scale: [0.0005150370998308063], zero_point: [0], shape: []}
+ {scale: [0.0005150370998308063], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.2785290479660034], zero_point: [144], shape: []}
+ {scale: [0.2785290479660034], zero_point: [144], shape: []},
],
{scale: [0.3821489214897156], zero_point: [134], shape: []},
{groups: 192});
const bottleneck4 = await this.buildLinearBottleneck_(
- bottleneck3, ['87', '89', '91'],
+ bottleneck3, ['87', '89', '91'],
[
- {scale: [0.001845767954364419 ], zero_point: [144], shape: []},
- {scale: [0.056680627167224884 ], zero_point: [138], shape: []},
- {scale: [0.027344657108187675], zero_point: [141], shape: []}
+ {scale: [0.001845767954364419], zero_point: [144], shape: []},
+ {scale: [0.056680627167224884], zero_point: [138], shape: []},
+ {scale: [0.027344657108187675], zero_point: [141], shape: []},
],
[
{scale: [0.0007053582230582833], zero_point: [0], shape: []},
{scale: [0.0013336879201233387], zero_point: [0], shape: []},
- {scale: [0.0006434162496589124], zero_point: [0], shape: []}
+ {scale: [0.0006434162496589124], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.31692883372306824], zero_point: [130], shape: []}
+ {scale: [0.31692883372306824], zero_point: [130], shape: []},
],
{scale: [0.4749276041984558], zero_point: [122], shape: []},
{groups: 192});
const bottleneck5 = await this.buildLinearBottleneck_(
- bottleneck4, ['94', '96', '98'],
+ bottleneck4, ['94', '96', '98'],
[
{scale: [0.0021686102263629436], zero_point: [147], shape: []},
{scale: [0.01276324037462473], zero_point: [141], shape: []},
- {scale: [0.01878936029970646], zero_point: [145], shape: []}
+ {scale: [0.01878936029970646], zero_point: [145], shape: []},
],
[
{scale: [0.0010299327550455928], zero_point: [0], shape: []},
{scale: [0.00030031739152036607], zero_point: [0], shape: []},
- {scale: [0.0004421112244017422], zero_point: [0], shape: []}
+ {scale: [0.0004421112244017422], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.2426074892282486], zero_point: [126], shape: []}
+ {scale: [0.2426074892282486], zero_point: [126], shape: []},
],
{}, {strides, groups: 192}, false);
const bottleneck6 = await this.buildLinearBottleneck_(
- bottleneck5, ['100', '102', '104'],
+ bottleneck5, ['100', '102', '104'],
[
{scale: [0.0018693081801757216], zero_point: [124], shape: []},
{scale: [0.057145655155181885], zero_point: [131], shape: []},
- {scale: [0.024178611114621162], zero_point: [173], shape: []}
+ {scale: [0.024178611114621162], zero_point: [173], shape: []},
],
[
{scale: [0.0004535081679932773], zero_point: [0], shape: []},
{scale: [0.0013446299126371741], zero_point: [0], shape: []},
- {scale: [0.0005689196405000985], zero_point: [0], shape: []}
+ {scale: [0.0005689196405000985], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.24107083678245544], zero_point: [99], shape: []}
- ],
+ {scale: [0.24107083678245544], zero_point: [99], shape: []},
+ ],
{scale: [0.272112101316452], zero_point: [122], shape: []},
{groups: 384});
const bottleneck7 = await this.buildLinearBottleneck_(
@@ -293,18 +293,18 @@ export class MobileNetV2Uint8Nhwc {
[
{scale: [0.0013072594301775098], zero_point: [139], shape: []},
{scale: [0.03875831514596939], zero_point: [143], shape: []},
- {scale: [0.021180255338549614], zero_point: [145], shape: []}
+ {scale: [0.021180255338549614], zero_point: [145], shape: []},
],
[
{scale: [0.0003557211020961404], zero_point: [0], shape: []},
{scale: [0.0009119781898334622], zero_point: [0], shape: []},
- {scale: [0.0004983686958439648], zero_point: [0], shape: []}
+ {scale: [0.0004983686958439648], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.21128740906715393], zero_point: [133], shape: []}
- ],
+ {scale: [0.21128740906715393], zero_point: [133], shape: []},
+ ],
{scale: [0.30632293224334717], zero_point: [119], shape: []},
{groups: 384});
const bottleneck8 = await this.buildLinearBottleneck_(
@@ -312,129 +312,129 @@ export class MobileNetV2Uint8Nhwc {
[
{scale: [0.0011219490552321076], zero_point: [138], shape: []},
{scale: [0.03533448651432991], zero_point: [107], shape: []},
- {scale: [0.025988703593611717 ], zero_point: [151], shape: []}
+ {scale: [0.025988703593611717], zero_point: [151], shape: []},
],
[
{scale: [0.0003436787228565663], zero_point: [0], shape: []},
{scale: [0.0008314159349538386], zero_point: [0], shape: []},
- {scale: [0.0006115108844824135], zero_point: [0], shape: []}
+ {scale: [0.0006115108844824135], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.2665506601333618], zero_point: [126], shape: []}
- ],
+ {scale: [0.2665506601333618], zero_point: [126], shape: []},
+ ],
{scale: [0.3084481954574585], zero_point: [130], shape: []},
{groups: 384});
const bottleneck9 = await this.buildLinearBottleneck_(
- bottleneck8, ['121', '123', '125'],
+ bottleneck8, ['121', '123', '125'],
[
{scale: [0.0015335703501477838], zero_point: [156], shape: []},
{scale: [0.02276834286749363], zero_point: [131], shape: []},
- {scale: [0.012576368637382984], zero_point: [100], shape: []}
+ {scale: [0.012576368637382984], zero_point: [100], shape: []},
],
[
{scale: [0.00047302700113505125], zero_point: [0], shape: []},
{scale: [0.0005357362097129226], zero_point: [0], shape: []},
- {scale: [0.0002959203557111323], zero_point: [0], shape: []}
+ {scale: [0.0002959203557111323], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.2215471863746643], zero_point: [126], shape: []}
- ],
+ {scale: [0.2215471863746643], zero_point: [126], shape: []},
+ ],
{}, {groups: 384}, false);
const bottleneck10 = await this.buildLinearBottleneck_(
bottleneck9, ['127', '129', '131'],
[
{scale: [0.0014903603587299585], zero_point: [110], shape: []},
{scale: [0.04933452978730202], zero_point: [131], shape: []},
- {scale: [0.012083801440894604], zero_point: [152], shape: []}
+ {scale: [0.012083801440894604], zero_point: [152], shape: []},
],
[
{scale: [0.0003301851393189281], zero_point: [0], shape: []},
{scale: [0.0011608351487666368], zero_point: [0], shape: []},
- {scale: [0.0002843302791006863], zero_point: [0], shape: []}
+ {scale: [0.0002843302791006863], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.1737687587738037], zero_point: [126], shape: []}
+ {scale: [0.1737687587738037], zero_point: [126], shape: []},
],
{scale: [0.23995822668075562], zero_point: [128], shape: []},
{groups: 576});
const bottleneck11 = await this.buildLinearBottleneck_(
- bottleneck10, ['134', '136', '138'],
+ bottleneck10, ['134', '136', '138'],
[
{scale: [0.0030131470412015915], zero_point: [142], shape: []},
{scale: [0.09067106992006302], zero_point: [107], shape: []},
- {scale: [0.01852469891309738], zero_point: [123], shape: []}
+ {scale: [0.01852469891309738], zero_point: [123], shape: []},
],
[
{scale: [0.000723029428627342], zero_point: [0], shape: []},
{scale: [0.002133478643372655], zero_point: [0], shape: []},
- {scale: [0.00043588379048742354], zero_point: [0], shape: []}
+ {scale: [0.00043588379048742354], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.2431039810180664], zero_point: [129], shape: []}
+ {scale: [0.2431039810180664], zero_point: [129], shape: []},
],
{scale: [0.3369702398777008], zero_point: [125], shape: []},
{groups: 576});
const bottleneck12 = await this.buildLinearBottleneck_(
bottleneck11, ['141', '143', '145'],
[
- {scale: [0.0016000346513465047 ], zero_point: [123], shape: []},
+ {scale: [0.0016000346513465047], zero_point: [123], shape: []},
{scale: [0.06790248304605484], zero_point: [69], shape: []},
- {scale: [0.01406034268438816], zero_point: [149], shape: []}
+ {scale: [0.01406034268438816], zero_point: [149], shape: []},
],
[
{scale: [0.0005391640588641167], zero_point: [0], shape: []},
{scale: [0.001597736612893641], zero_point: [0], shape: []},
- {scale: [0.0003308380546513945], zero_point: [0], shape: []}
+ {scale: [0.0003308380546513945], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.18441903591156006], zero_point: [137], shape: []}
+ {scale: [0.18441903591156006], zero_point: [137], shape: []},
],
{}, {strides, groups: 576}, false);
const bottleneck13 = await this.buildLinearBottleneck_(
- bottleneck12, ['147', '149', '151'],
+ bottleneck12, ['147', '149', '151'],
[
{scale: [0.002654522191733122], zero_point: [118], shape: []},
{scale: [0.0386493057012558], zero_point: [148], shape: []},
- {scale: [0.012022278271615505], zero_point: [152], shape: []}
+ {scale: [0.012022278271615505], zero_point: [152], shape: []},
],
[
{scale: [0.0004895444144494832], zero_point: [0], shape: []},
{scale: [0.0009094132692553103], zero_point: [0], shape: []},
- {scale: [0.0002828826545737684], zero_point: [0], shape: []}
+ {scale: [0.0002828826545737684], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.15274383127689362 ], zero_point: [117], shape: []}
+ {scale: [0.15274383127689362], zero_point: [117], shape: []},
],
{scale: [0.18867115676403046], zero_point: [125], shape: []},
{groups: 960});
const bottleneck14 = await this.buildLinearBottleneck_(
- bottleneck13, ['154', '156', '158'],
+ bottleneck13, ['154', '156', '158'],
[
{scale: [0.002872081007808447], zero_point: [171], shape: []},
{scale: [0.042505424469709396], zero_point: [174], shape: []},
- {scale: [0.07219446450471878], zero_point: [89], shape: []}
+ {scale: [0.07219446450471878], zero_point: [89], shape: []},
],
[
{scale: [0.0005418788641691208], zero_point: [0], shape: []},
{scale: [0.0010001471964642406], zero_point: [0], shape: []},
- {scale: [0.0016987264389172196], zero_point: [0], shape: []}
+ {scale: [0.0016987264389172196], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.46753087639808655], zero_point: [168], shape: []}
+ {scale: [0.46753087639808655], zero_point: [168], shape: []},
],
{scale: [0.5439051389694214], zero_point: [175], shape: []},
{groups: 960});
@@ -443,42 +443,42 @@ export class MobileNetV2Uint8Nhwc {
[
{scale: [0.0015946601051837206], zero_point: [115], shape: []},
{scale: [0.05400737002491951], zero_point: [154], shape: []},
- {scale: [0.025818506255745888], zero_point: [67], shape: []}
+ {scale: [0.025818506255745888], zero_point: [67], shape: []},
],
[
{scale: [0.0008673437987454236], zero_point: [0], shape: []},
{scale: [0.001270786509849131], zero_point: [0], shape: []},
- {scale: [0.0006075061392039061], zero_point: [0], shape: []}
+ {scale: [0.0006075061392039061], zero_point: [0], shape: []},
],
[
{scale: [0.02352987229824066], zero_point: [0], shape: []},
{scale: [0.02352987229824066], zero_point: [0], shape: []},
- {scale: [0.25913476943969727], zero_point: [176], shape: []}
+ {scale: [0.25913476943969727], zero_point: [176], shape: []},
],
{}, {groups: 960}, false);
-
+
const conv3 = await this.buildConv_(
- bottleneck15, '167',
- {scale: [0.003932334017008543], zero_point: [116], shape: []},
- {scale: [0.001019004499539733], zero_point: [0], shape: []},
- {scale: [0.02352987229824066], zero_point: [0], shape: []},
- true, {autoPad, filterLayout});
-
+ bottleneck15, '167',
+ {scale: [0.003932334017008543], zero_point: [116], shape: []},
+ {scale: [0.001019004499539733], zero_point: [0], shape: []},
+ {scale: [0.02352987229824066], zero_point: [0], shape: []},
+ true, {autoPad, filterLayout});
+
const poolQuantize = {scale: [0.02352987229824066], zero_point: [0], shape: []};
const averagePool2d = this.builder_.averagePool2d(
- conv3, {windowDimensions: [7, 7], layout: 'nhwc'});
-
+ conv3, {windowDimensions: [7, 7], layout: 'nhwc'});
+
const quantize2 = this.quantizeLinear_(averagePool2d, poolQuantize);
const dequantize = this.dequantizeLinear_(quantize2, poolQuantize, 'uint8');
const conv4 = await this.buildConv_(
- dequantize, '170',
- {scale: [0.002771724946796894], zero_point: [105], shape: []},
- {scale: [0.00006521832983708009], zero_point: [0], shape: []},
- {scale: [0.06046031787991524], zero_point: [60], shape: []},
- false, {autoPad, filterLayout});
+ dequantize, '170',
+ {scale: [0.002771724946796894], zero_point: [105], shape: []},
+ {scale: [0.00006521832983708009], zero_point: [0], shape: []},
+ {scale: [0.06046031787991524], zero_point: [60], shape: []},
+ false, {autoPad, filterLayout});
const reshape = this.builder_.reshape(conv4, [1, 1001]);
const softmax = this.builder_.softmax(reshape, 1);
-
+
return softmax;
}
diff --git a/image_classification/resnet50v2_nchw.js b/image_classification/resnet50v2_nchw.js
index 27a72c4d..0f4375ee 100644
--- a/image_classification/resnet50v2_nchw.js
+++ b/image_classification/resnet50v2_nchw.js
@@ -4,12 +4,13 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
// ResNet50 V2 model with 'nchw' input layout
export class ResNet50V2Nchw {
- constructor() {
+ constructor(dataType = 'float32') {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
this.outputTensor_ = null;
+ this.targetDataType_ = dataType;
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/resnet50v2_nchw/weights/';
this.inputOptions = {
@@ -32,7 +33,8 @@ export class ResNet50V2Nchw {
prefix = this.weightsUrl_ + 'resnetv24_conv' + name;
}
const weightName = prefix + '_weight.npy';
- const weight = buildConstantByNpy(this.builder_, weightName);
+ const weight = buildConstantByNpy(
+ this.builder_, weightName, this.targetDataType_);
return this.builder_.conv2d(await input, await weight, options);
}
@@ -48,10 +50,14 @@ export class ResNet50V2Nchw {
const biasName = prefix + '_beta.npy';
const meanName = prefix + '_running_mean.npy';
const varName = prefix + '_running_var.npy';
- const scale = buildConstantByNpy(this.builder_, scaleName);
- const bias = buildConstantByNpy(this.builder_, biasName);
- const mean = buildConstantByNpy(this.builder_, meanName);
- const variance = buildConstantByNpy(this.builder_, varName);
+ const scale = buildConstantByNpy(
+ this.builder_, scaleName, this.targetDataType_);
+ const bias = buildConstantByNpy(
+ this.builder_, biasName, this.targetDataType_);
+ const mean = buildConstantByNpy(
+ this.builder_, meanName, this.targetDataType_);
+ const variance = buildConstantByNpy(
+ this.builder_, varName, this.targetDataType_);
const options = {scale: await scale, bias: await bias};
const batchnorm = this.builder_.batchNormalization(
await input,
@@ -65,9 +71,11 @@ export class ResNet50V2Nchw {
async buildGemm_(input, name) {
const prefix = this.weightsUrl_ + 'resnetv24_dense' + name;
const weightName = prefix + '_weight.npy';
- const weight = buildConstantByNpy(this.builder_, weightName);
+ const weight = buildConstantByNpy(
+ this.builder_, weightName, this.targetDataType_);
const biasName = prefix + '_bias.npy';
- const bias = buildConstantByNpy(this.builder_, biasName);
+ const bias = buildConstantByNpy(
+ this.builder_, biasName, this.targetDataType_);
const options =
{c: this.builder_.reshape(await bias, [1, 1000]), bTranspose: true};
return this.builder_.gemm(await input, await weight, options);
@@ -105,7 +113,7 @@ export class ResNet50V2Nchw {
dimensions: this.inputOptions.inputShape,
shape: this.inputOptions.inputShape,
};
- const data = this.builder_.input('input', inputDesc);
+ let data = this.builder_.input('input', inputDesc);
inputDesc.usage = MLTensorUsage.WRITE;
inputDesc.writable = true;
this.inputTensor_ = await this.context_.createTensor(inputDesc);
@@ -116,6 +124,10 @@ export class ResNet50V2Nchw {
usage: MLTensorUsage.READ,
readable: true,
});
+
+ if (this.targetDataType_ === 'float16') {
+ data = this.builder_.cast(data, 'float16');
+ }
const bn1 = this.buildBatchNorm_(data, '0', '', false);
const conv0 = this.buildConv_(
bn1, '0', '', {padding: [3, 3, 3, 3], strides: [2, 2]});
@@ -167,7 +179,12 @@ export class ResNet50V2Nchw {
const pool2 = this.builder_.averagePool2d(await bn3);
const reshape = this.builder_.reshape(await pool2, [1, 2048]);
const gemm = this.buildGemm_(await reshape, '0');
- return this.builder_.softmax(await gemm, 1);
+ const softmax = this.builder_.softmax(await gemm, 1);
+
+ if (this.targetDataType_ === 'float16') {
+ return this.builder_.cast(softmax, 'float32');
+ }
+ return softmax;
}
async build(outputOperand) {
diff --git a/image_classification/resnet50v2_nhwc.js b/image_classification/resnet50v2_nhwc.js
index 4c110ee6..7d1bf189 100644
--- a/image_classification/resnet50v2_nhwc.js
+++ b/image_classification/resnet50v2_nhwc.js
@@ -8,12 +8,13 @@ const layout = 'nhwc';
// ResNet 50 V2 model with 'nhwc' layout
export class ResNet50V2Nhwc {
- constructor() {
+ constructor(dataType = 'float32') {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
this.outputTensor_ = null;
+ this.targetDataType_ = dataType;
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/resnet50v2_nhwc/weights/';
this.inputOptions = {
@@ -44,9 +45,11 @@ export class ResNet50V2Nhwc {
prefix += 'conv' + nameIndices[2];
}
const weightsName = prefix + '_weights.npy';
- const weights = await buildConstantByNpy(this.builder_, weightsName);
+ const weights = await buildConstantByNpy(
+ this.builder_, weightsName, this.targetDataType_);
const biasName = prefix + '_Conv2D_bias.npy';
- const bias = buildConstantByNpy(this.builder_, biasName);
+ const bias = buildConstantByNpy(
+ this.builder_, biasName, this.targetDataType_);
options.inputLayout = layout;
options.filterLayout = 'ohwi';
options.bias = await bias;
@@ -75,9 +78,11 @@ export class ResNet50V2Nhwc {
`block${nameIndices[0]}_unit_${nameIndices[1]}_bottleneck_v2_preact`;
}
const mulParamName = prefix + '_FusedBatchNorm_mul_0_param.npy';
- const mulParam = buildConstantByNpy(this.builder_, mulParamName);
+ const mulParam = buildConstantByNpy(
+ this.builder_, mulParamName, this.targetDataType_);
const addParamName = prefix + '_FusedBatchNorm_add_param.npy';
- const addParam = buildConstantByNpy(this.builder_, addParamName);
+ const addParam = buildConstantByNpy(
+ this.builder_, addParamName, this.targetDataType_);
return this.builder_.relu(
this.builder_.add(
this.builder_.mul(await input, await mulParam),
@@ -131,7 +136,7 @@ export class ResNet50V2Nhwc {
dimensions: this.inputOptions.inputShape,
shape: this.inputOptions.inputShape,
};
- const input = this.builder_.input('input', inputDesc);
+ let input = this.builder_.input('input', inputDesc);
inputDesc.usage = MLTensorUsage.WRITE;
inputDesc.writable = true;
this.inputTensor_ = await this.context_.createTensor(inputDesc);
@@ -142,6 +147,10 @@ export class ResNet50V2Nhwc {
usage: MLTensorUsage.READ,
readable: true,
});
+
+ if (this.targetDataType_ === 'float16') {
+ input = this.builder_.cast(input, 'float16');
+ }
const conv1 = await this.buildConv_(
input, ['', '', '1'], {strides, padding: [3, 3, 3, 3]}, false);
const windowDimensions = [3, 3];
@@ -201,7 +210,12 @@ export class ResNet50V2Nhwc {
const conv2 = this.buildConv_(
mean, ['', '', 'logits'], {autoPad}, false);
const reshape = this.builder_.reshape(await conv2, [1, 1001]);
- return this.builder_.softmax(reshape, 1);
+ const softmax = this.builder_.softmax(reshape, 1);
+
+ if (this.targetDataType_ === 'float16') {
+ return this.builder_.cast(softmax, 'float32');
+ }
+ return softmax;
}
async build(outputOperand) {
diff --git a/image_classification/squeezenet_nchw.js b/image_classification/squeezenet_nchw.js
index c76330ed..e35f981d 100644
--- a/image_classification/squeezenet_nchw.js
+++ b/image_classification/squeezenet_nchw.js
@@ -4,12 +4,13 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
// SqueezeNet 1.1 model with 'nchw' input layout
export class SqueezeNetNchw {
- constructor() {
+ constructor(dataType = 'float32') {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
this.outputTensor_ = null;
+ this.targetDataType_ = dataType;
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/squeezenet1.1_nchw/weights/';
this.inputOptions = {
@@ -26,9 +27,11 @@ export class SqueezeNetNchw {
async buildConv_(input, name, options = {}) {
const prefix = this.weightsUrl_ + 'squeezenet0_' + name;
const weightsName = prefix + '_weight.npy';
- const weights = buildConstantByNpy(this.builder_, weightsName);
+ const weights = buildConstantByNpy(
+ this.builder_, weightsName, this.targetDataType_);
const biasName = prefix + '_bias.npy';
- const bias = buildConstantByNpy(this.builder_, biasName);
+ const bias = buildConstantByNpy(
+ this.builder_, biasName, this.targetDataType_);
options.bias = await bias;
const conv2d = this.builder_.conv2d(await input, await weights, options);
return this.builder_.relu(conv2d);
@@ -50,7 +53,7 @@ export class SqueezeNetNchw {
dimensions: this.inputOptions.inputShape,
shape: this.inputOptions.inputShape,
};
- const data = this.builder_.input('input', inputDesc);
+ let data = this.builder_.input('input', inputDesc);
inputDesc.usage = MLTensorUsage.WRITE;
inputDesc.writable = true;
this.inputTensor_ = await this.context_.createTensor(inputDesc);
@@ -61,6 +64,10 @@ export class SqueezeNetNchw {
usage: MLTensorUsage.READ,
readable: true,
});
+
+ if (this.targetDataType_ === 'float16') {
+ data = this.builder_.cast(data, 'float16');
+ }
const conv0 = this.buildConv_(data, 'conv0', {strides: [2, 2]});
const pool0 = this.builder_.maxPool2d(
await conv0, {windowDimensions: [3, 3], strides: [2, 2]});
@@ -80,7 +87,12 @@ export class SqueezeNetNchw {
const pool3 = this.builder_.averagePool2d(
await conv25, {windowDimensions: [13, 13], strides: [13, 13]});
const reshape0 = this.builder_.reshape(pool3, [1, 1000]);
- return this.builder_.softmax(reshape0, 1);
+ const softmax = this.builder_.softmax(reshape0, 1);
+
+ if (this.targetDataType_ === 'float16') {
+ return this.builder_.cast(softmax, 'float32');
+ }
+ return softmax;
}
async build(outputOperand) {
diff --git a/image_classification/squeezenet_nhwc.js b/image_classification/squeezenet_nhwc.js
index debbe9b6..490095da 100644
--- a/image_classification/squeezenet_nhwc.js
+++ b/image_classification/squeezenet_nhwc.js
@@ -4,12 +4,13 @@ import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../
// SqueezeNet 1.0 model with 'nhwc' layout
export class SqueezeNetNhwc {
- constructor() {
+ constructor(dataType = 'float32') {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.inputTensor_ = null;
this.outputTensor_ = null;
+ this.targetDataType_ = dataType;
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/squeezenet1.0_nhwc/weights/';
this.inputOptions = {
@@ -25,9 +26,11 @@ export class SqueezeNetNhwc {
async buildConv_(input, name, options = {}) {
const prefix = this.weightsUrl_ + name;
const weightsName = prefix + '_kernel.npy';
- const weights = await buildConstantByNpy(this.builder_, weightsName);
+ const weights = await buildConstantByNpy(
+ this.builder_, weightsName, this.targetDataType_);
const biasName = prefix + '_Conv2D_bias.npy';
- const bias = buildConstantByNpy(this.builder_, biasName);
+ const bias = buildConstantByNpy(
+ this.builder_, biasName, this.targetDataType_);
options.inputLayout = 'nhwc';
options.filterLayout = 'ohwi';
options.bias = await bias;
@@ -65,7 +68,7 @@ export class SqueezeNetNhwc {
dimensions: this.inputOptions.inputShape,
shape: this.inputOptions.inputShape,
};
- const placeholder = this.builder_.input('input', inputDesc);
+ let placeholder = this.builder_.input('input', inputDesc);
inputDesc.usage = MLTensorUsage.WRITE;
inputDesc.writable = true;
this.inputTensor_ = await this.context_.createTensor(inputDesc);
@@ -76,6 +79,10 @@ export class SqueezeNetNhwc {
usage: MLTensorUsage.READ,
readable: true,
});
+
+ if (this.targetDataType_ === 'float16') {
+ placeholder = this.builder_.cast(placeholder, 'float16');
+ }
const conv1 = this.buildConv_(
placeholder, 'conv1', {strides, autoPad: 'same-upper'});
const maxpool1 = this.builder_.maxPool2d(
@@ -96,7 +103,12 @@ export class SqueezeNetNhwc {
const averagePool2d = this.builder_.averagePool2d(
await conv10, {windowDimensions: [13, 13], layout});
const reshape = this.builder_.reshape(averagePool2d, [1, 1001]);
- return this.builder_.softmax(reshape, 1);
+ const softmax = this.builder_.softmax(reshape, 1);
+
+ if (this.targetDataType_ === 'float16') {
+ return this.builder_.cast(softmax, 'float32');
+ }
+ return softmax;
}
async build(outputOperand) {
diff --git a/object_detection/index.html b/object_detection/index.html
index 90aa8985..d94ad7ca 100644
--- a/object_detection/index.html
+++ b/object_detection/index.html
@@ -30,15 +30,15 @@
Backend