Skip to content

Commit 4eea2af

Browse files
JackMoriartyliutongxuan
authored andcommitted
[Op] Implement GPU version of ApplyAdamAsync, SparseApplyAdamAsync, KvSparseApplyAdamAsync.
1 parent 38af546 commit 4eea2af

File tree

7 files changed

+1450
-847
lines changed

7 files changed

+1450
-847
lines changed

tensorflow/core/kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6227,6 +6227,7 @@ tf_kernel_library(
62276227
gpu_srcs = [
62286228
"training_ali_ops_gpu.cu.cc",
62296229
"training_ali_ops_gpu.h",
6230+
"training_ali_ops.h"
62306231
],
62316232
copts = ["-g"],
62326233
deps = [

tensorflow/core/kernels/training_ali_ops.cc

Lines changed: 387 additions & 152 deletions
Large diffs are not rendered by default.

tensorflow/core/kernels/training_ali_ops.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,30 @@ template <typename Device, typename T>
3838
struct ApplyAdamAsync {
3939
void operator()(const Device& d, typename TTypes<T>::Flat var,
4040
typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
41-
typename TTypes<T>::Flat beta1_power,
42-
typename TTypes<T>::Flat beta2_power,
41+
typename TTypes<T>::Scalar beta1_power,
42+
typename TTypes<T>::Scalar beta2_power,
4343
typename TTypes<T>::ConstScalar lr,
4444
typename TTypes<T>::ConstScalar beta1,
4545
typename TTypes<T>::ConstScalar beta2,
4646
typename TTypes<T>::ConstScalar epsilon,
4747
typename TTypes<T>::ConstFlat grad, bool use_nesterov);
4848
};
4949

50+
template <typename Device, typename T, typename Tindex>
51+
struct SparseApplyAdamAsync {
52+
Status operator()(const Device &d, typename TTypes<T>::Matrix var,
53+
typename TTypes<T>::Matrix m, typename TTypes<T>::Matrix v,
54+
typename TTypes<T>::Scalar beta1_power,
55+
typename TTypes<T>::Scalar beta2_power,
56+
typename TTypes<T>::ConstScalar lr,
57+
typename TTypes<T>::ConstScalar beta1,
58+
typename TTypes<T>::ConstScalar beta2,
59+
typename TTypes<T>::ConstScalar epsilon,
60+
typename TTypes<T>::ConstMatrix grad,
61+
typename TTypes<Tindex>::ConstVec indices_vec,
62+
bool apply_sparse_rmsprop, int64 inner_dim);
63+
};
64+
5065
} // end namespace functor
5166
} // end namespace tensorflow
5267

0 commit comments

Comments
 (0)