@@ -16,20 +16,30 @@ limitations under the License.
1616#include " tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h"
1717
1818#include < algorithm>
19+ #include < cstdint>
1920#include < functional>
2021#include < memory>
2122#include < utility>
2223
2324#include " google/protobuf/wrappers.pb.h"
25+ #include " absl/base/thread_annotations.h"
26+ #include " absl/log/log.h"
27+ #include " absl/status/status.h"
2428#include " tensorflow/cc/saved_model/constants.h"
2529#include " xla/tsl/platform/errors.h"
2630#include " tensorflow/core/kernels/batching_util/warmup.h"
2731#include " tensorflow/core/lib/core/errors.h"
28- #include " tensorflow/core/lib/io/path.h"
2932#include " tensorflow/core/lib/io/record_reader.h"
3033#include " tensorflow/core/lib/monitoring/sampler.h"
34+ #include " tensorflow/core/platform/env.h"
35+ #include " tensorflow/core/platform/env_time.h"
36+ #include " tensorflow/core/platform/file_system.h"
3137#include " tensorflow/core/platform/mutex.h"
32- #include " tensorflow/core/platform/status.h"
38+ #include " tensorflow/core/platform/path.h"
39+ #include " tensorflow/core/platform/strcat.h"
40+ #include " tensorflow/core/platform/tstring.h"
41+ #include " tensorflow/core/platform/types.h"
42+ #include " tensorflow_serving/util/executor.h"
3343#include " tensorflow_serving/util/threadpool_executor.h"
3444
3545namespace tensorflow {
@@ -58,22 +68,9 @@ uint64_t GetLatencyMicroseconds(const uint64_t start_microseconds) {
5868constexpr char WarmupConsts::kRequestsFileName [];
5969constexpr int WarmupConsts::kMaxNumRecords ;
6070
61- absl::Status RunSavedModelWarmup (
71+ absl::Status RunSavedModelWarmupUntracked (
6272 const ModelWarmupOptions& model_warmup_options, const string export_dir,
6373 std::function<absl::Status(PredictionLog)> warmup_request_executor) {
64- WarmupStateRegistry::Handle warmup_handle;
65- auto per_model_data = std::make_unique<WarmupStateRegistry::PerModelData>();
66- per_model_data->warmup_all_batch_sizes =
67- model_warmup_options.enable_all_batch_sizes_warmup ();
68- if (!model_warmup_options.model_name ().empty ()) {
69- auto h = GetGlobalWarmupStateRegistry ().Register (
70- {model_warmup_options.model_name (),
71- model_warmup_options.model_version ()},
72- std::move (per_model_data));
73- TF_RETURN_IF_ERROR (h.status ());
74- warmup_handle = std::move (h.value ());
75- }
76-
7774 const uint64_t start_microseconds = EnvTime::NowMicros ();
7875 const string warmup_path =
7976 io::JoinPath (export_dir, kSavedModelAssetsExtraDirectory ,
@@ -237,6 +234,26 @@ absl::Status RunSavedModelWarmup(
237234 return absl::OkStatus ();
238235}
239236
237+ absl::Status RunSavedModelWarmup (
238+ const ModelWarmupOptions& model_warmup_options, const string export_dir,
239+ std::function<absl::Status(PredictionLog)> warmup_request_executor) {
240+ WarmupStateRegistry::Handle warmup_handle;
241+ auto per_model_data = std::make_unique<WarmupStateRegistry::PerModelData>();
242+ per_model_data->warmup_all_batch_sizes =
243+ model_warmup_options.enable_all_batch_sizes_warmup ();
244+ if (!model_warmup_options.model_name ().empty ()) {
245+ auto h = GetGlobalWarmupStateRegistry ().Register (
246+ {model_warmup_options.model_name (),
247+ model_warmup_options.model_version ()},
248+ std::move (per_model_data));
249+ TF_RETURN_IF_ERROR (h.status ());
250+ warmup_handle = std::move (h.value ());
251+ }
252+
253+ return RunSavedModelWarmupUntracked (model_warmup_options, export_dir,
254+ warmup_request_executor);
255+ }
256+
240257} // namespace internal
241258} // namespace serving
242259} // namespace tensorflow
0 commit comments