• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/executor.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/graph/testlib.h"
25 #include "tensorflow/core/lib/core/threadpool.h"
26 #include "tensorflow/core/platform/macros.h"
27 #include "tensorflow/core/platform/types.h"
28 
// Forward declaration only: this header refers to the benchmark loop state
// exclusively through references, so the complete type is not required here.
namespace testing {
namespace benchmark { class State; }
}  // namespace testing
34 
35 namespace tensorflow {
36 
37 class Device;
38 class FunctionLibraryRuntime;
39 class ProcessFunctionLibraryRuntime;
40 struct SessionOptions;
41 class StaticDeviceMgr;
42 
43 namespace test {
44 
45 class Benchmark {
46  public:
47   // "device" must be either "cpu" or "gpu".  Takes ownership of "g",
48   // "init", and one reference on "rendez" (if not null).
49   //
50   // old_benchmark_api: If true, the benchmark is running with older API
51   //   * In the old API, the timer needs to be stopped/restarted
52   //     by users.
53   //   * In the new API, the timer starts automatically at the first
54   //     iteration of the loop and stops after the last iteration.
55   // TODO(vyng) Remove this once we have migrated all code to newer API.
56   Benchmark(const string& device, Graph* g,
57             const SessionOptions* options = nullptr, Graph* init = nullptr,
58             Rendezvous* rendez = nullptr, const char* executor_type = "",
59             bool old_benchmark_api = true);
60 
61   Benchmark(const string& device, Graph* g, bool old_benchmark_api);
62 
63   ~Benchmark();
64 
65   // Executes the graph for "iters" times.
66   // This function is deprecated. Use the overload that takes
67   // `benchmark::State&`
68   // instead.
69   [[deprecated("use `Run(benchmark::State&)` instead.")]] void Run(int iters);
70 
71   void Run(::testing::benchmark::State& state);
72 
73   // If "g" contains send/recv nodes, before each execution, we send
74   // inputs to the corresponding recv keys in the graph, after each
75   // execution, we recv outputs from the corresponding send keys in
76   // the graph. In the benchmark, we throw away values returned by the
77   // graph.
78   // This function is deprecated. Use the overload that takes
79   // `benchmark::State&` instead.
80   [[deprecated(
81       "use `RunWithRendezvousArgs(...,benchmark::State&)` instead.")]] void
82   RunWithRendezvousArgs(const std::vector<std::pair<string, Tensor>>& inputs,
83                         const std::vector<string>& outputs, int iters);
84 
85   void RunWithRendezvousArgs(
86       const std::vector<std::pair<string, Tensor>>& inputs,
87       const std::vector<string>& outputs, ::testing::benchmark::State& state);
88 
89  private:
90   thread::ThreadPool* pool_ = nullptr;  // Not owned.
91   Device* device_ = nullptr;            // Not owned.
92   Rendezvous* rendez_ = nullptr;
93   std::unique_ptr<StaticDeviceMgr> device_mgr_;
94   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
95   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
96   FunctionLibraryRuntime* flr_;  // Not owned.
97   std::unique_ptr<Executor> exec_;
98   bool old_benchmark_api_;
99 
100   TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
101 };
102 
103 // Returns the rendezvous key associated with the given Send/Recv node.
104 string GetRendezvousKey(const Node* node);
105 
106 }  // end namespace test
107 }  // end namespace tensorflow
108 
109 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
110