Commit d06dc7e

Add object detection API for AutoML (#1971)
FEATURE

- Add pre- and post-processing ops for the model (fast non-max suppression)
- Add the API for loading and running an object detection model trained by AutoML
- Add unit tests for both Node and the browser environment using a real model produced by AutoML
1 parent f39acb2 commit d06dc7e

21 files changed: +6736 -66 lines changed

tfjs-automl/README.md

Lines changed: 87 additions & 3 deletions
@@ -22,7 +22,7 @@ If you are using CDN:
 
 We support the following types of AutoML Edge models:
 1) [Image classification](#image-classification)
-2) **[In progress]** [Object detection](#object-detection)
+2) [Object detection](#object-detection)
 
 ## Image classification
 
@@ -67,7 +67,7 @@ a 3D [`Tensor`](https://js.tensorflow.org/api/latest/#class:Tensor):
 
 ```js
 const img = document.getElementById('img');
-const options = {};
+const options = {centerCrop: true};
 const predictions = await model.classify(img, options);
 ```
 
@@ -88,6 +88,90 @@ probabilities:
 ]
 ```
 
+### Advanced usage
+
+Advanced users can access the underlying
+[`GraphModel`](https://js.tensorflow.org/api/latest/#class:GraphModel) via
+`model.graphModel`. The `GraphModel` allows users to call lower level methods
+such as `predict()`, `execute()` and `executeAsync()` which return tensors.
+
+`model.dictionary` gives you access to the ordered list of labels.
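
As a rough illustration of how the ordered label list might be combined with raw model output, the sketch below pairs a plain array of per-class scores with a dictionary array and picks the best match. It assumes that the score at index `i` corresponds to `dictionary[i]`, which the README does not state explicitly. `topLabel`, the sample labels, and the sample scores are hypothetical and not part of the library.

```js
// Hypothetical helper, not part of @tensorflow/tfjs-automl.
// Assumption: scores[i] is the score for dictionary[i].
function topLabel(scores, dictionary) {
  if (scores.length !== dictionary.length) {
    throw new Error('scores and dictionary must have the same length');
  }
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return {label: dictionary[best], prob: scores[best]};
}

// Made-up values for illustration only.
const sampleDictionary = ['daisy', 'dandelion', 'roses'];
const sampleScores = [0.1, 0.7, 0.2];
const top = topLabel(sampleScores, sampleDictionary);
console.log(top);
```

In a real app the scores would presumably be read back from a tensor returned by one of the `GraphModel` methods listed above, and the dictionary would come from `model.dictionary`.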
+
 ## Object detection
 
-TODO(smilkov): Write this when object detection is ready.
+An AutoML object detection model outputs the following set of files:
+- `model.json`, the model topology
+- `dict.txt`, a newline-separated list of labels
+- One or more `*.bin` files that hold the weights
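
To make the `dict.txt` format concrete, here is a minimal sketch of turning a newline-separated label file into an array. The helper name `parseLabels` and the sample file contents are hypothetical; how the library itself reads the file is not shown in this README.

```js
// Hypothetical helper: splits the text of a dict.txt file into labels.
function parseLabels(dictText) {
  return dictText
      .split(/\r?\n/)                    // one label per line
      .map(line => line.trim())          // tolerate stray whitespace
      .filter(line => line.length > 0);  // skip blank lines, e.g. a trailing newline
}

// Made-up file contents for illustration only.
const labels = parseLabels('background\nTomato\nSalad\n');
console.log(labels);
```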
+
+Make sure you can access those files as static assets from your web app by serving them locally or on Google Cloud Storage.
+
+### Demo
+
+The object detection demo lives in
+[demo/object_detection](./demo/object_detection). To run it:
+
+```sh
+cd demo/object_detection
+yarn
+yarn watch
+```
+
+This will start a local HTTP server on port 1234 that serves the demo.
+
+### Loading the model
+```js
+import * as automl from '@tensorflow/tfjs-automl';
+const modelUrl = 'model.json'; // URL to the model.json file.
+const model = await automl.loadObjectDetection(modelUrl);
+```
+
+### Making a prediction
+The input `img` can be
+[`HTMLImageElement`](https://developer.mozilla.org/en-US/docs/Web/API/HTMLImageElement),
+[`HTMLCanvasElement`](https://developer.mozilla.org/en-US/docs/Web/API/HTMLCanvasElement),
+[`HTMLVideoElement`](https://developer.mozilla.org/en-US/docs/Web/API/HTMLVideoElement),
+[`ImageData`](https://developer.mozilla.org/en-US/docs/Web/API/ImageData) or
+a 3D [`Tensor`](https://js.tensorflow.org/api/latest/#class:Tensor):
+
+```html
+<img id="img" src="PATH_TO_IMAGE" />
+```
+
+```js
+const img = document.getElementById('img');
+const options = {score: 0.5, iou: 0.5, topk: 20};
+const predictions = await model.detect(img, options);
+```
+
+`options` is optional and has the following properties:
+- `score` - Probability score between 0 and 1. Defaults to 0.5. Boxes with a score lower than this threshold are ignored.
+- `topk` - Only the `topk` most likely objects are returned. Defaults to 20. The actual number of objects might be less than this number.
+- `iou` - Intersection over union threshold. Defaults to 0.5. IoU is a metric between 0 and 1 used to measure the overlap of two boxes. The predicted boxes will not overlap more than the specified threshold.
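
The three options can be illustrated with a small sketch of IoU and a greedy non-max suppression pass. This is a plain textbook version written for this explanation, not the fast non-max suppression the commit adds, and it assumes boxes in the `{left, top, width, height}` form used by the results. The names `iou` and `filterDetections` and the sample candidates are hypothetical.

```js
// Intersection over union of two boxes in {left, top, width, height} form.
function iou(a, b) {
  const x1 = Math.max(a.left, b.left);
  const y1 = Math.max(a.top, b.top);
  const x2 = Math.min(a.left + a.width, b.left + b.width);
  const y2 = Math.min(a.top + a.height, b.top + b.height);
  const inter = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
  const union = a.width * a.height + b.width * b.height - inter;
  return union > 0 ? inter / union : 0;
}

// Greedy non-max suppression (illustrative only): drop low scores, visit the
// rest from highest to lowest score, and keep a box only if it does not
// overlap an already kept box by more than the IoU threshold.
function filterDetections(
    candidates, {score = 0.5, iou: iouThreshold = 0.5, topk = 20} = {}) {
  const sorted = candidates.filter(c => c.score >= score)
                     .sort((a, b) => b.score - a.score);
  const kept = [];
  for (const c of sorted) {
    if (kept.length >= topk) break;
    if (kept.every(k => iou(k.box, c.box) <= iouThreshold)) kept.push(c);
  }
  return kept;
}

// Made-up candidates for illustration only.
const candidates = [
  {box: {left: 0, top: 0, width: 10, height: 10}, label: 'Tomato', score: 0.9},
  {box: {left: 1, top: 0, width: 10, height: 10}, label: 'Tomato', score: 0.8},
  {box: {left: 100, top: 100, width: 10, height: 10}, label: 'Tomato', score: 0.6},
  {box: {left: 50, top: 50, width: 10, height: 10}, label: 'Tomato', score: 0.3},
];
const kept = filterDetections(candidates);
console.log(kept.map(k => k.score));
```

With the default thresholds, the second candidate is suppressed because it overlaps the first too heavily, and the last one falls below the score threshold. Whether the library suppresses boxes per label or across all labels is not stated here; this sketch ignores labels.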
151+
152+
The result `predictions` is a sorted list of predicted objects:
153+
154+
```js
155+
[
156+
{
157+
box: {
158+
left: 105.1,
159+
top: 22.2,
160+
width: 70.6,
161+
height: 55.7
162+
},
163+
label: "Tomato",
164+
score: 0.972
165+
},
166+
...
167+
]
168+
```
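
As a short sketch of consuming this result shape, the helpers below convert a box to corner coordinates and count detections per label. Both helper names are hypothetical, and the second sample prediction is made up for illustration.

```js
// Sample data in the shape shown above; the second entry is made up.
const samplePredictions = [
  {box: {left: 105.1, top: 22.2, width: 70.6, height: 55.7}, label: 'Tomato', score: 0.972},
  {box: {left: 10, top: 40, width: 30, height: 20}, label: 'Salad', score: 0.61},
];

// Hypothetical helper: converts {left, top, width, height} to corner coordinates.
function toCorners({left, top, width, height}) {
  return {xMin: left, yMin: top, xMax: left + width, yMax: top + height};
}

// Hypothetical helper: counts how many detections each label has.
function countByLabel(predictions) {
  const counts = {};
  for (const {label} of predictions) {
    counts[label] = (counts[label] || 0) + 1;
  }
  return counts;
}

console.log(samplePredictions.map(p => toCorners(p.box)));
console.log(countByLabel(samplePredictions));
```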
169+
170+
### Advanced usage
171+
172+
Advanced users can access the underlying
173+
[`GraphModel`](https://js.tensorflow.org/api/latest/#class:GraphModel) via
174+
`model.graphModel`. The `GraphModel` allows users to call lower level methods
175+
such as `predict()`, `execute()` and `executeAsync()` which return tensors.
176+
177+
`model.dictionary` gives you access to the ordered list of labels.

tfjs-automl/demo/img_classification/package.json

Lines changed: 6 additions & 2 deletions
@@ -1,8 +1,8 @@
 {
-  "name": "automl-demo",
+  "name": "automl-img-classification-demo",
   "version": "0.0.1",
   "private": true,
-  "description": "Demo of using automl client",
+  "description": "Image classification demo using the AutoML NPM library",
   "main": "index.js",
   "scripts": {
     "watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open ",
@@ -18,5 +18,9 @@
     "parcel-bundler": "~1.10.3",
     "yalc": "~1.0.0-pre.27"
   },
+  "dependencies": {
+    "@tensorflow/tfjs-converter": "^1.2.8",
+    "@tensorflow/tfjs-core": "^1.2.8"
+  },
   "license": "Apache-2.0"
 }

tfjs-automl/demo/img_classification/yarn.lock

Lines changed: 47 additions & 0 deletions
@@ -697,11 +697,48 @@
   resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-1.1.3.tgz#2b5a3ab3f918cca48a8c754c08168e3f03eba61b"
   integrity sha512-shAmDyaQC4H92APFoIaVDHCx5bStIocgvbwQyxPRrbUY20V1EYTbSDchWbuwlMG3V17cprZhA6+78JfB+3DTPw==
 
+"@tensorflow/tfjs-converter@^1.2.8":
+  version "1.2.8"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-1.2.8.tgz#86fa47be3e92a90d4191956f08015a17b93c3ef9"
+  integrity sha512-weHzkNRVgnY9TcbA3XTneNgCyuIXLjF8ks8YbFA+81i2w6qO90xiAdWtP2YmR+F9K9S4WR3bSSB0AQKZAp+mPQ==
+
+"@tensorflow/tfjs-core@^1.2.8":
+  version "1.2.8"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.2.8.tgz#d6873b88522f8cf25d34c10afd095866578d7d92"
+  integrity sha512-lWV4vAnXAAmahXpCWBwdGGW9HO6iNw9pUeVYih7pDXeJahMk3OJs6SgjRNhwn+ldsGwRoorR0/RHg0yNLmqWxQ==
+  dependencies:
+    "@types/offscreencanvas" "~2019.3.0"
+    "@types/seedrandom" "2.4.27"
+    "@types/webgl-ext" "0.0.30"
+    "@types/webgl2" "0.0.4"
+    node-fetch "~2.1.2"
+    seedrandom "2.4.3"
+
+"@types/offscreencanvas@~2019.3.0":
+  version "2019.3.0"
+  resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.3.0.tgz#3336428ec7e9180cf4566dfea5da04eb586a6553"
+  integrity sha512-esIJx9bQg+QYF0ra8GnvfianIY8qWB0GBx54PK5Eps6m+xTj86KLavHv6qDhzKcu5UUOgNfJ2pWaIIV7TRUd9Q==
+
 "@types/q@^1.5.1":
   version "1.5.2"
   resolved "https://registry.yarnpkg.com/@types/q/-/q-1.5.2.tgz#690a1475b84f2a884fd07cd797c00f5f31356ea8"
   integrity sha512-ce5d3q03Ex0sy4R14722Rmt6MT07Ua+k4FwDfdcToYJcMKNtRVQvJ6JCAPdAmAnbRb6CsX6aYb9m96NGod9uTw==
 
+"@types/seedrandom@2.4.27":
+  version "2.4.27"
+  resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.27.tgz#9db563937dd86915f69092bc43259d2f48578e41"
+  integrity sha1-nbVjk33YaRX2kJK8QyWdL0hXjkE=
+
+"@types/webgl-ext@0.0.30":
+  version "0.0.30"
+  resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d"
+  integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg==
+
+"@types/webgl2@0.0.4":
+  version "0.0.4"
+  resolved "https://registry.yarnpkg.com/@types/webgl2/-/webgl2-0.0.4.tgz#c3b0f9d6b465c66138e84e64cb3bdf8373c2c279"
+  integrity sha512-PACt1xdErJbMUOUweSrbVM7gSIYm1vTncW2hF6Os/EeWi6TXYAYMPp+8v6rzHmypE5gHrxaxZNXgMkJVIdZpHw==
+
 abbrev@1:
   version "1.1.1"
   resolved "https://registry.yarnpkg.com/abbrev/-/abbrev-1.1.1.tgz#f8f2c887ad10bf67f634f005b6987fed3179aac8"
@@ -3688,6 +3725,11 @@ node-addon-api@^1.6.0:
   resolved "https://registry.yarnpkg.com/node-addon-api/-/node-addon-api-1.7.1.tgz#cf813cd69bb8d9100f6bdca6755fc268f54ac492"
   integrity sha512-2+DuKodWvwRTrCfKOeR24KIc5unKjOh8mz17NCzVnHWfjAdDqbfbjqh7gUT+BkXBRQM52+xCHciKWonJ3CbJMQ==
 
+node-fetch@~2.1.2:
+  version "2.1.2"
+  resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.1.2.tgz#ab884e8e7e57e38a944753cec706f788d1768bb5"
+  integrity sha1-q4hOjn5X44qUR1POxwb3iNF2i7U=
+
 node-forge@^0.7.1:
   version "0.7.6"
   resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-0.7.6.tgz#fdf3b418aee1f94f0ef642cd63486c77ca9724ac"
@@ -5121,6 +5163,11 @@ sax@^1.2.4, sax@~1.2.1, sax@~1.2.4:
   resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9"
   integrity sha512-NqVDv9TpANUjFm0N8uM5GxL36UgKi9/atZw+x7YFnQ8ckwFGKrl4xX4yWtrey3UJm5nP1kUbnYgLopqWNSRhWw==
 
+seedrandom@2.4.3:
+  version "2.4.3"
+  resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-2.4.3.tgz#2438504dad33917314bff18ac4d794f16d6aaecc"
+  integrity sha1-JDhQTa0zkXMUv/GKxNeU8W1qrsw=
+
 "semver@2 || 3 || 4 || 5", semver@^5.3.0, semver@^5.4.1, semver@^5.5.0, semver@^5.6.0:
   version "5.7.1"
   resolved "https://registry.yarnpkg.com/semver/-/semver-5.7.1.tgz#a954f931aeba508d307bbf069eff0c01c96116f7"
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+{
+  "presets": [
+    [
+      "env",
+      {
+        "esmodules": false,
+        "targets": {
+          "browsers": [
+            "> 3%"
+          ]
+        }
+      }
+    ]
+  ],
+  "plugins": []
+}
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+<!DOCTYPE html>
+<html>
+  <head>
+    <title>AutoML Object detection demo</title>
+    <style>
+    </style>
+  </head>
+  <body>
+    <div id="imgDiv" style="position:relative;">
+      <img id="salad" width="500" crossorigin="anonymous" src="https://storage.googleapis.com/tfjs-testing/tfjs-automl/object_detection/test_image.jpg" />
+      <svg width="500" height="375" style="position: absolute;top:0;left:0;">
+        <style>
+          .box {
+            stroke-width: 2;
+            fill: none;
+            stroke: red;
+          }
+          .label {
+            font-size: 12px;
+            fill: white;
+            text-anchor: middle;
+          }
+          .label-rect {
+            fill: black;
+          }
+        </style>
+      </svg>
+    </div>
+    <script src="index.js"></script>
+  </body>
+</html>
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+/**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+// TODO(smilkov): Import from "@tensorflow/tfjs-automl" when the package
+// is released.
+import * as automl from '../../src/index';
+
+const MODEL_URL =
+    'https://storage.googleapis.com/tfjs-testing/tfjs-automl/object_detection/model.json';
+
+async function run() {
+  const model = await automl.loadObjectDetection(MODEL_URL);
+  const image = document.getElementById('salad');
+  // These are the default options.
+  const options = {score: 0.5, iou: 0.5, topk: 20};
+  const predictions = await model.detect(image, options);
+  drawBoxes(predictions);
+}
+
+// Overlays boxes with labels onto the image using `rect` and `text` svg
+// elements.
+function drawBoxes(predictions) {
+  const svg = document.querySelector('svg');
+  predictions.forEach(prediction => {
+    const {box, label, score} = prediction;
+    const {left, top, width, height} = box;
+    const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
+    rect.setAttribute('width', width);
+    rect.setAttribute('height', height);
+    rect.setAttribute('x', left);
+    rect.setAttribute('y', top);
+    rect.setAttribute('class', 'box');
+    const text = document.createElementNS('http://www.w3.org/2000/svg', 'text');
+    text.setAttribute('x', left + width / 2);
+    text.setAttribute('y', top);
+    text.setAttribute('dy', 12);
+    text.setAttribute('class', 'label');
+    text.textContent = `${label}: ${score.toFixed(3)}`;
+    svg.appendChild(rect);
+    svg.appendChild(text);
+    const textBBox = text.getBBox();
+    const textRect =
+        document.createElementNS('http://www.w3.org/2000/svg', 'rect');
+    textRect.setAttribute('x', textBBox.x);
+    textRect.setAttribute('y', textBBox.y);
+    textRect.setAttribute('width', textBBox.width);
+    textRect.setAttribute('height', textBBox.height);
+    textRect.setAttribute('class', 'label-rect');
+    svg.insertBefore(textRect, text);
+  });
+}
+
+run();
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+{
+  "name": "automl-object-detection-demo",
+  "version": "0.0.1",
+  "private": true,
+  "description": "Object detection demo using the AutoML NPM library",
+  "main": "index.js",
+  "scripts": {
+    "watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open ",
+    "build": "cross-env NODE_ENV=production parcel build index.html --public-url ./"
+  },
+  "devDependencies": {
+    "babel-core": "^6.26.3",
+    "babel-plugin-transform-runtime": "~6.23.0",
+    "babel-polyfill": "~6.26.0",
+    "babel-preset-env": "~1.6.1",
+    "clang-format": "~1.2.2",
+    "cross-env": "^5.2.0",
+    "parcel-bundler": "~1.10.3",
+    "yalc": "~1.0.0-pre.27"
+  },
+  "license": "Apache-2.0",
+  "dependencies": {
+    "@tensorflow/tfjs-converter": "^1.2.8",
+    "@tensorflow/tfjs-core": "^1.2.8"
+  }
+}
