Skip to content

Commit 083ad23

Browse files
authored
Merge pull request #330 from Honry/fix-softmax
Use latest softmax API
2 parents 3ff6eef + 342c79c commit 083ad23

File tree

10 files changed

+11
-11
lines changed

10 files changed

+11
-11
lines changed

image_classification/efficientnet_fp16_nchw.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ export class EfficientNetFP16Nchw {
161161
const pool1 = this.builder_.averagePool2d(await conv22);
162162
const reshape = this.builder_.reshape(pool1, [1, 1280]);
163163
const gemm = this.buildGemm_(reshape, '0');
164-
const softmax = this.builder_.softmax(await gemm);
164+
const softmax = this.builder_.softmax(await gemm, 1);
165165

166166
return this.builder_.cast(softmax, 'float32');
167167
}

image_classification/mobilenet_nchw.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,13 @@ export class MobileNetV2Nchw {
153153
const pool = this.builder_.averagePool2d(await conv3);
154154
const reshape = this.builder_.reshape(pool, [1, 1280]);
155155
const gemm = this.buildGemm_(reshape, '104');
156-
return this.builder_.softmax(await gemm);
156+
return this.builder_.softmax(await gemm, 1);
157157
} else {
158158
const conv4 = this.buildConv_(await conv3, '97', false,
159159
{groups: 1280, strides: [7, 7]});
160160
const conv5 = this.buildConv_(await conv4, '104', false);
161161
const reshape = this.builder_.reshape(await conv5, [1, 1000]);
162-
const softmax = this.builder_.softmax(reshape);
162+
const softmax = this.builder_.softmax(reshape, 1);
163163
return this.builder_.cast(softmax, 'float32');
164164
}
165165
}

image_classification/mobilenet_nhwc.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ export class MobileNetV2Nhwc {
153153
const conv4 = this.buildConv_(
154154
averagePool2d, '222', 'Logits_Conv2d_1c_1x1_Conv2D', false, {autoPad, filterLayout});
155155
const reshape = this.builder_.reshape(await conv4, [1, 1001]);
156-
return await this.builder_.softmax(reshape);
156+
return await this.builder_.softmax(reshape, 1);
157157
}
158158

159159
async build(outputOperand) {

image_classification/mobilenet_uint8_nhwc.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ export class MobileNetV2Uint8Nhwc {
465465
{scale: [0.06046031787991524], zero_point: [60], shape: []},
466466
false, {autoPad, filterLayout});
467467
const reshape = this.builder_.reshape(conv4, [1, 1001]);
468-
const softmax = this.builder_.softmax(reshape);
468+
const softmax = this.builder_.softmax(reshape, 1);
469469

470470
return softmax;
471471
}

image_classification/resnet50v1_fp16_nchw.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ export class ResNet50V1FP16Nchw {
130130
const pool2 = this.builder_.averagePool2d(await bottleneck16);
131131
const reshape = this.builder_.reshape(pool2, [1, 2048]);
132132
const gemm = this.buildGemm_(reshape, '0');
133-
const softmax = this.builder_.softmax(await gemm);
133+
const softmax = this.builder_.softmax(await gemm, 1);
134134
return this.builder_.cast(softmax, 'float32');
135135
}
136136

image_classification/resnet50v2_nchw.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ export class ResNet50V2Nchw {
167167
const pool2 = this.builder_.averagePool2d(await bn3);
168168
const reshape = this.builder_.reshape(await pool2, [1, 2048]);
169169
const gemm = this.buildGemm_(await reshape, '0');
170-
return this.builder_.softmax(await gemm);
170+
return this.builder_.softmax(await gemm, 1);
171171
}
172172

173173
async build(outputOperand) {

image_classification/resnet50v2_nhwc.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ export class ResNet50V2Nhwc {
201201
const conv2 = this.buildConv_(
202202
mean, ['', '', 'logits'], {autoPad}, false);
203203
const reshape = this.builder_.reshape(await conv2, [1, 1001]);
204-
return this.builder_.softmax(reshape);
204+
return this.builder_.softmax(reshape, 1);
205205
}
206206

207207
async build(outputOperand) {

image_classification/squeezenet_nchw.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ export class SqueezeNetNchw {
8080
const pool3 = this.builder_.averagePool2d(
8181
await conv25, {windowDimensions: [13, 13], strides: [13, 13]});
8282
const reshape0 = this.builder_.reshape(pool3, [1, 1000]);
83-
return this.builder_.softmax(reshape0);
83+
return this.builder_.softmax(reshape0, 1);
8484
}
8585

8686
async build(outputOperand) {

image_classification/squeezenet_nhwc.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ export class SqueezeNetNhwc {
9696
const averagePool2d = this.builder_.averagePool2d(
9797
await conv10, {windowDimensions: [13, 13], layout});
9898
const reshape = this.builder_.reshape(averagePool2d, [1, 1001]);
99-
return this.builder_.softmax(reshape);
99+
return this.builder_.softmax(reshape, 1);
100100
}
101101

102102
async build(outputOperand) {

lenet/lenet.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ export class LeNet {
174174
new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add4BiasShape)));
175175
const add4 = this.builder_.add(matmul2, add4Bias);
176176

177-
return this.builder_.softmax(add4);
177+
return this.builder_.softmax(add4, 1);
178178
}
179179

180180
async build(outputOperand) {

0 commit comments

Comments
 (0)