Skip to content

Commit 4d838cb

Browse files
committed
[tfjs-core] backend for encodeBase64, decodeBase64
1 parent 2a99938 commit 4d838cb

File tree

4 files changed

+96
-6
lines changed

4 files changed

+96
-6
lines changed

tfjs-core/src/backends/backend.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google Inc. All Rights Reserved.
3+
* Copyright 2019 Google Inc. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -17,7 +17,7 @@
1717

1818
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
1919
import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util';
20-
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
20+
import {Backend, DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
2121
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';
2222

2323
export const EPSILON_FLOAT32 = 1e-7;
@@ -650,4 +650,13 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
650650
dispose(): void {
651651
throw new Error('Not yet implemented');
652652
}
653+
654+
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
655+
T {
656+
throw new Error('Not yet implemented');
657+
}
658+
659+
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
660+
throw new Error('Not yet implemented');
661+
}
653662
}

tfjs-core/src/backends/cpu/backend_cpu.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2017 Google Inc. All Rights Reserved.
3+
* Copyright 2019 Google Inc. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -34,7 +34,7 @@ import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
3434
import * as scatter_nd_util from '../../ops/scatter_nd_util';
3535
import * as selu_util from '../../ops/selu_util';
3636
import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util';
37-
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
37+
import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
3838
import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
3939
import * as util from '../../util';
4040
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
@@ -43,6 +43,7 @@ import * as backend_util from '../backend_util';
4343
import * as complex_util from '../complex_util';
4444
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
4545
import {split} from '../split_shared';
46+
import {decodeBase64Impl, encodeBase64Impl} from '../string_shared';
4647
import {tile} from '../tile_impl';
4748
import {topkImpl} from '../topk_impl';
4849
import {whereImpl} from '../where_impl';
@@ -3624,6 +3625,17 @@ export class MathBackendCPU implements KernelBackend {
36243625

36253626
dispose() {}
36263627

3628+
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
3629+
T {
3630+
const sVals = this.readSync(str.dataId) as Uint8Array[];
3631+
return encodeBase64Impl(sVals, str.shape, pad);
3632+
}
3633+
3634+
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
3635+
const sVals = this.readSync(str.dataId) as Uint8Array[];
3636+
return decodeBase64Impl(sVals, str.shape);
3637+
}
3638+
36273639
floatPrecision(): 16|32 {
36283640
return 32;
36293641
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {arrayBufferToBase64String, base64StringToArrayBuffer, urlSafeBase64, urlUnsafeBase64} from '../io/io_utils';
19+
import {StringTensor, Tensor} from '../tensor';
20+
import {decodeString} from '../util';
21+
22+
/** Shared implementation of the encodeBase64 kernel across WebGL and CPU. */
23+
export function encodeBase64Impl<T extends StringTensor>(
24+
values: Uint8Array[], shape: number[], pad = false): T {
25+
const resultValues = new Array(values.length);
26+
27+
for (let i = 0; i < values.length; ++i) {
28+
const bStr = arrayBufferToBase64String(values[i].buffer);
29+
const bStrUrl = urlSafeBase64(bStr);
30+
31+
if (pad) {
32+
resultValues[i] = bStrUrl;
33+
} else {
34+
// Remove padding
35+
resultValues[i] = bStrUrl.replace(/=/g, '');
36+
}
37+
}
38+
39+
return Tensor.make(shape, {values: resultValues}, 'string');
40+
}
41+
42+
/** Shared implementation of the decodeBase64 kernel across WebGL and CPU. */
43+
export function decodeBase64Impl<T extends StringTensor>(
44+
values: Uint8Array[], shape: number[]): T {
45+
const resultValues = new Array(values.length);
46+
47+
for (let i = 0; i < values.length; ++i) {
48+
// Undo URL safe and decode from Base64 to ArrayBuffer
49+
const bStrUrl = decodeString(values[i]);
50+
const bStr = urlUnsafeBase64(bStrUrl);
51+
const aBuff = base64StringToArrayBuffer(bStr);
52+
53+
resultValues[i] = decodeString(new Uint8Array(aBuff));
54+
}
55+
56+
return Tensor.make(shape, {values: resultValues}, 'string');
57+
}

tfjs-core/src/backends/webgl/backend_webgl.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2017 Google Inc. All Rights Reserved.
3+
* Copyright 2019 Google Inc. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -37,7 +37,7 @@ import * as segment_util from '../../ops/segment_util';
3737
import * as slice_util from '../../ops/slice_util';
3838
import {softmax} from '../../ops/softmax';
3939
import {range, scalar, tensor} from '../../ops/tensor_ops';
40-
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
40+
import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
4141
import {BackendValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
4242
import * as util from '../../util';
4343
import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';
@@ -46,6 +46,7 @@ import * as backend_util from '../backend_util';
4646
import {mergeRealAndImagArrays} from '../complex_util';
4747
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
4848
import {split} from '../split_shared';
49+
import {decodeBase64Impl, encodeBase64Impl} from '../string_shared';
4950
import {tile} from '../tile_impl';
5051
import {topkImpl} from '../topk_impl';
5152
import {whereImpl} from '../where_impl';
@@ -2412,6 +2413,17 @@ export class MathBackendWebGL implements KernelBackend {
24122413
return split(x, sizeSplits, axis);
24132414
}
24142415

2416+
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
2417+
T {
2418+
const sVals = this.readSync(str.dataId) as Uint8Array[];
2419+
return encodeBase64Impl(sVals, str.shape, pad);
2420+
}
2421+
2422+
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
2423+
const sVals = this.readSync(str.dataId) as Uint8Array[];
2424+
return decodeBase64Impl(sVals, str.shape);
2425+
}
2426+
24152427
scatterND<R extends Rank>(
24162428
indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor<R> {
24172429
const {sliceRank, numUpdates, sliceSize, strides, outputSize} =

0 commit comments

Comments
 (0)