1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/core/common_runtime/executor.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/graph/testlib.h" 25 #include "tensorflow/core/lib/core/threadpool.h" 26 #include "tensorflow/core/platform/macros.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace testing { 30 namespace benchmark { 31 class State; 32 } // namespace benchmark 33 } // namespace testing 34 35 namespace tensorflow { 36 37 class Device; 38 class FunctionLibraryRuntime; 39 class ProcessFunctionLibraryRuntime; 40 struct SessionOptions; 41 class StaticDeviceMgr; 42 43 namespace test { 44 45 class Benchmark { 46 public: 47 // "device" must be either "cpu" or "gpu". Takes ownership of "g", 48 // "init", and one reference on "rendez" (if not null). 49 // 50 // old_benchmark_api: If true, the benchmark is running with older API 51 // * In the old API, the timer needs to be stopped/restarted 52 // by users. 53 // * In the new API, the timer starts automatically at the first 54 // iteration of the loop and stops after the last iteration. 55 // TODO(vyng) Remove this once we have migrated all code to newer API. 56 Benchmark(const string& device, Graph* g, 57 const SessionOptions* options = nullptr, Graph* init = nullptr, 58 Rendezvous* rendez = nullptr, const char* executor_type = "", 59 bool old_benchmark_api = true); 60 61 Benchmark(const string& device, Graph* g, bool old_benchmark_api); 62 63 ~Benchmark(); 64 65 // Executes the graph for "iters" times. 66 // This function is deprecated. Use the overload that takes 67 // `benchmark::State&` 68 // instead. 69 [[deprecated("use `Run(benchmark::State&)` instead.")]] void Run(int iters); 70 71 void Run(::testing::benchmark::State& state); 72 73 // If "g" contains send/recv nodes, before each execution, we send 74 // inputs to the corresponding recv keys in the graph, after each 75 // execution, we recv outputs from the corresponding send keys in 76 // the graph. In the benchmark, we throw away values returned by the 77 // graph. 78 // This function is deprecated. Use the overload that takes 79 // `benchmark::State&` instead. 80 [[deprecated( 81 "use `RunWithRendezvousArgs(...,benchmark::State&)` instead.")]] void 82 RunWithRendezvousArgs(const std::vector<std::pair<string, Tensor>>& inputs, 83 const std::vector<string>& outputs, int iters); 84 85 void RunWithRendezvousArgs( 86 const std::vector<std::pair<string, Tensor>>& inputs, 87 const std::vector<string>& outputs, ::testing::benchmark::State& state); 88 89 private: 90 thread::ThreadPool* pool_ = nullptr; // Not owned. 91 Device* device_ = nullptr; // Not owned. 92 Rendezvous* rendez_ = nullptr; 93 std::unique_ptr<StaticDeviceMgr> device_mgr_; 94 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 95 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; 96 FunctionLibraryRuntime* flr_; // Not owned. 97 std::unique_ptr<Executor> exec_; 98 bool old_benchmark_api_; 99 100 TF_DISALLOW_COPY_AND_ASSIGN(Benchmark); 101 }; 102 103 // Returns the rendezvous key associated with the given Send/Recv node. 104 string GetRendezvousKey(const Node* node); 105 106 } // end namespace test 107 } // end namespace tensorflow 108 109 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ 110