/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"

#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace test {

// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
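//
// A rough usage sketch (not taken from this file; the graph-construction
// step is an assumption) of how kernel benchmarks typically drive this
// class:
//
//   void BM_MyOp(benchmark::State& state) {
//     Graph* g = new Graph(OpRegistry::Global());
//     // ... add the nodes to benchmark to `g` ...
//     test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
//   }
//
// The Benchmark object takes ownership of `g` (and of `init`, if provided);
// both are deleted when the constructor finishes.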
Benchmark::Benchmark(const string& device, Graph* g,
                     const SessionOptions* options, Graph* init,
                     Rendezvous* rendez, const char* executor_type,
                     bool old_benchmark_api) {
  auto cleanup = gtl::MakeCleanup([g, init]() {
    delete g;
    delete init;
  });

  SessionOptions default_options;
  if (!options) {
    options = &default_options;
  }

  CHECK(!old_benchmark_api) << "Expected new API only";

  string t = absl::AsciiStrToUpper(device);
  // Allow NewDevice to allocate a new threadpool with a different number of
  // threads for each new benchmark.
  LocalDevice::set_use_global_threadpool(false);

  device_mgr_ = std::make_unique<StaticDeviceMgr>(
      DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"));
  device_ = device_mgr_->ListDevices()[0];
  CHECK(device_) << "Could not create a " << device << " device";

  pool_ =
      new thread::ThreadPool(options->env, "blocking", port::MaxParallelism());

  auto runner = [this](std::function<void()> closure) {
    pool_->Schedule(closure);
  };

  if (rendez == nullptr) {
    rendez_ = NewLocalRendezvous();
  } else {
    rendez_ = rendez;
  }

  const int graph_def_version = g->versions().producer();

  flib_def_ = std::make_unique<FunctionLibraryDefinition>(g->flib_def());

  pflr_ = std::unique_ptr<ProcessFunctionLibraryRuntime>(
      new ProcessFunctionLibraryRuntime(
          device_mgr_.get(), Env::Default(), nullptr, graph_def_version,
          flib_def_.get(), OptimizerOptions(), pool_, nullptr, nullptr,
          Rendezvous::Factory()));

  flr_ = pflr_->GetFLR(device_->name());

  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr_;
  params.create_kernel = [this, graph_def_version](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_, flr_, props, graph_def_version,
                                 kernel);
  };
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };

  // If an initialization graph was provided, run it once to completion before
  // building the executor for the benchmarked graph.
  if (init) {
    std::unique_ptr<Executor> init_exec;
    TF_CHECK_OK(NewExecutor(executor_type, params, *init, &init_exec));
    Executor::Args args;
    args.rendezvous = rendez_;
    args.runner = runner;
    TF_CHECK_OK(init_exec->Run(args));
  }

  TF_CHECK_OK(NewExecutor(executor_type, params, *g, &exec_));
}

Benchmark::Benchmark(const string& device, Graph* g, bool old_benchmark_api)
    : Benchmark(device, g, nullptr, nullptr, nullptr, "", old_benchmark_api) {}

Benchmark::~Benchmark() {
  if (device_) {
    rendez_->Unref();
    // We delete `exec_` before `device_mgr_` because the `exec_` destructor
    // may run kernel destructors that may attempt to access state borrowed
    // from `device_mgr_`, such as the resource manager.
    exec_.reset();
    pflr_.reset();
    device_mgr_.reset();
    delete pool_;
  }
}

void Benchmark::Run(benchmark::State& state) {
  RunWithRendezvousArgs({}, {}, state);
}

string GetRendezvousKey(const Node* node) {
  string send_device;
  TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
  string recv_device;
  TF_CHECK_OK(GetNodeAttr(node->attrs(), "recv_device", &recv_device));
  string tensor_name;
  TF_CHECK_OK(GetNodeAttr(node->attrs(), "tensor_name", &tensor_name));
  uint64 send_device_incarnation;
  TF_CHECK_OK(
      GetNodeAttr(node->attrs(), "send_device_incarnation",
                  reinterpret_cast<int64_t*>(&send_device_incarnation)));
  return Rendezvous::CreateKey(send_device, send_device_incarnation,
                               recv_device, tensor_name, FrameAndIter(0, 0));
}
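
// Note: the returned string has the layout produced by Rendezvous::CreateKey,
// roughly (example values are assumptions, not taken from this file):
//
//   <send_device>;<hex incarnation>;<recv_device>;<tensor_name>;0:0
//
// Callers pass such keys in the `inputs`/`outputs` arguments of
// RunWithRendezvousArgs() below, where they are parsed back with
// Rendezvous::ParseKey().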

void Benchmark::RunWithRendezvousArgs(
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& outputs, benchmark::State& state) {
  if (!device_ || state.max_iterations == 0) {
    return;
  }
  Tensor unused;  // In the benchmark, we don't care about the return value.
  bool is_dead;

  // Warm up.
  Executor::Args args;
  args.rendezvous = rendez_;
  args.runner = [this](std::function<void()> closure) {
    pool_->Schedule(closure);
  };
  static const int kWarmupRuns = 3;
  for (int i = 0; i < kWarmupRuns; ++i) {
    for (const auto& p : inputs) {
      Rendezvous::ParsedKey parsed;
      TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
      TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
    }
    TF_CHECK_OK(exec_->Run(args));
    for (const string& key : outputs) {
      Rendezvous::ParsedKey parsed;
      TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
      TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
    }
  }
  TF_CHECK_OK(device_->Sync());
  VLOG(3) << kWarmupRuns << " warmup runs done.";

  // Benchmark loop. The timer starts automatically at the beginning of the
  // loop and stops automatically after the last iteration.
  for (auto s : state) {
    for (const auto& p : inputs) {
      Rendezvous::ParsedKey parsed;
      TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
      TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
    }
    TF_CHECK_OK(exec_->Run(args));
    for (const string& key : outputs) {
      Rendezvous::ParsedKey parsed;
      TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
      TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
    }
  }
  TF_CHECK_OK(device_->Sync());
}

}  // end namespace test
}  // end namespace tensorflow