Skip to content

Commit 25398a5

Browse files
committed
[tfjs-core] encodeBase64, decodeBase64 ops
1 parent 4d838cb commit 25398a5

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

tfjs-core/src/ops/ops.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google Inc. All Rights Reserved.
3+
* Copyright 2019 Google Inc. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -49,6 +49,7 @@ export * from './gather_nd';
4949
export * from './diag';
5050
export * from './dropout';
5151
export * from './signal_ops';
52+
export * from './string_ops';
5253
export * from './in_top_k';
5354

5455
export {op} from './operation';

tfjs-core/src/ops/string_ops.ts

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import {ENGINE} from '../engine';
18+
import {StringTensor, Tensor} from '../tensor';
19+
import {convertToTensor} from '../tensor_util_env';
20+
21+
import {op} from './operation';
22+
23+
/**
24+
* Encodes the values of a `tf.Tensor` (of dtype `string`) to Base64.
25+
*
26+
* Given a String tensor, returns a new tensor with the values encoded into
27+
* web-safe base64 format.
28+
*
29+
* Web-safe means that the encoder uses `-` and `_` instead of `+` and `/`:
30+
*
31+
* en.wikipedia.org/wiki/Base64
32+
*
33+
* ```js
34+
* const x = tf.tensor1d(['Hello world!'], 'string');
35+
*
36+
* x.encodeBase64().print();
37+
* ```
38+
* @param str The input `tf.Tensor` of dtype `string` to encode.
39+
* @param pad Whether to add padding (`=`) to the end of the encoded string.
40+
*/
41+
/** @doc {heading: 'Operations', subheading: 'String'} */
42+
function encodeBase64_<T extends StringTensor>(
43+
str: StringTensor|Tensor, pad = false): T {
44+
const $str = convertToTensor(str, 'str', 'encodeBase64', 'string');
45+
46+
const backwardsFunc = (dy: T) => ({$str: () => decodeBase64(dy)});
47+
48+
return ENGINE.runKernel(
49+
backend => backend.encodeBase64($str, pad), {$str}, backwardsFunc);
50+
}
51+
52+
/**
53+
* Decodes the values of a `tf.Tensor` (of dtype `string`) from Base64.
54+
*
55+
* Given a String tensor of Base64 encoded values, returns a new tensor with the
56+
* decoded values.
57+
*
58+
* en.wikipedia.org/wiki/Base64
59+
*
60+
* ```js
61+
* const y = tf.scalar('SGVsbG8gd29ybGQh', 'string');
62+
*
63+
* y.decodeBase64().print();
64+
* ```
65+
* @param str The input `tf.Tensor` of dtype `string` to decode.
66+
*/
67+
/** @doc {heading: 'Operations', subheading: 'String'} */
68+
function decodeBase64_<T extends StringTensor>(str: StringTensor|Tensor): T {
69+
const $str = convertToTensor(str, 'str', 'decodeBase64', 'string');
70+
71+
const backwardsFunc = (dy: T) => ({$str: () => encodeBase64(dy)});
72+
73+
return ENGINE.runKernel(
74+
backend => backend.decodeBase64($str), {$str}, backwardsFunc);
75+
}
76+
77+
export const encodeBase64 = op({encodeBase64_});
78+
export const decodeBase64 = op({decodeBase64_});

tfjs-core/src/tensor.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2017 Google Inc. All Rights Reserved.
3+
* Copyright 2019 Google Inc. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -367,6 +367,8 @@ export interface OpHandler {
367367
fft(x: Tensor): Tensor; ifft(x: Tensor): Tensor; rfft(x: Tensor): Tensor;
368368
irfft(x: Tensor): Tensor
369369
};
370+
encodeBase64<T extends StringTensor>(x: T, pad: boolean): T;
371+
decodeBase64<T extends StringTensor>(x: T): T;
370372
}
371373

372374
// For tracking tensor creation and disposal.
@@ -1426,6 +1428,16 @@ export class Tensor<R extends Rank = Rank> {
14261428
this.throwIfDisposed();
14271429
return opHandler.spectral.irfft(this);
14281430
}
1431+
1432+
encodeBase64<T extends StringTensor>(this: T, pad = false): T {
1433+
this.throwIfDisposed();
1434+
return opHandler.encodeBase64(this, pad);
1435+
}
1436+
1437+
decodeBase64<T extends StringTensor>(this: T): T {
1438+
this.throwIfDisposed();
1439+
return opHandler.decodeBase64(this);
1440+
}
14291441
}
14301442
Object.defineProperty(Tensor, Symbol.hasInstance, {
14311443
value: (instance: Tensor) => {

0 commit comments

Comments
 (0)