@@ -38,15 +38,30 @@ template <typename Device, typename T>
3838struct ApplyAdamAsync {
3939 void operator ()(const Device& d, typename TTypes<T>::Flat var,
4040 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
41- typename TTypes<T>::Flat beta1_power,
42- typename TTypes<T>::Flat beta2_power,
41+ typename TTypes<T>::Scalar beta1_power,
42+ typename TTypes<T>::Scalar beta2_power,
4343 typename TTypes<T>::ConstScalar lr,
4444 typename TTypes<T>::ConstScalar beta1,
4545 typename TTypes<T>::ConstScalar beta2,
4646 typename TTypes<T>::ConstScalar epsilon,
4747 typename TTypes<T>::ConstFlat grad, bool use_nesterov);
4848};
4949
50+ template <typename Device, typename T, typename Tindex>
51+ struct SparseApplyAdamAsync {
52+ Status operator ()(const Device &d, typename TTypes<T>::Matrix var,
53+ typename TTypes<T>::Matrix m, typename TTypes<T>::Matrix v,
54+ typename TTypes<T>::Scalar beta1_power,
55+ typename TTypes<T>::Scalar beta2_power,
56+ typename TTypes<T>::ConstScalar lr,
57+ typename TTypes<T>::ConstScalar beta1,
58+ typename TTypes<T>::ConstScalar beta2,
59+ typename TTypes<T>::ConstScalar epsilon,
60+ typename TTypes<T>::ConstMatrix grad,
61+ typename TTypes<Tindex>::ConstVec indices_vec,
62+ bool apply_sparse_rmsprop, int64 inner_dim);
63+ };
64+
5065} // end namespace functor
5166} // end namespace tensorflow
5267
0 commit comments