1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
17
18 #include <vector>
19
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_factory.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/common_runtime/executor_factory.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/local_device.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/op_segment.h"
29 #include "tensorflow/core/framework/versions.pb.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/kernels/ops_util.h"
32 #include "tensorflow/core/lib/core/notification.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/platform/byte_order.h"
37 #include "tensorflow/core/platform/cpu_info.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/public/session_options.h"
41 #include "tensorflow/core/public/version.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43
44 namespace tensorflow {
45 namespace test {
46
47 // TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
Benchmark(const string & device,Graph * g,const SessionOptions * options,Graph * init,Rendezvous * rendez,const char * executor_type,bool old_benchmark_api)48 Benchmark::Benchmark(const string& device, Graph* g,
49 const SessionOptions* options, Graph* init,
50 Rendezvous* rendez, const char* executor_type,
51 bool old_benchmark_api) {
52 auto cleanup = gtl::MakeCleanup([g, init]() {
53 delete g;
54 delete init;
55 });
56
57 SessionOptions default_options;
58 if (!options) {
59 options = &default_options;
60 }
61
62 CHECK(!old_benchmark_api) << "Expected new API only";
63
64 string t = absl::AsciiStrToUpper(device);
65 // Allow NewDevice to allocate a new threadpool with different number of
66 // threads for each new benchmark.
67 LocalDevice::set_use_global_threadpool(false);
68
69 device_mgr_ = std::make_unique<StaticDeviceMgr>(
70 DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"));
71 device_ = device_mgr_->ListDevices()[0];
72 CHECK(device_) << "Could not create a " << device << " device";
73
74 pool_ =
75 new thread::ThreadPool(options->env, "blocking", port::MaxParallelism());
76
77 auto runner = [this](std::function<void()> closure) {
78 pool_->Schedule(closure);
79 };
80
81 if (rendez == nullptr) {
82 rendez_ = NewLocalRendezvous();
83 } else {
84 rendez_ = rendez;
85 }
86
87 const int graph_def_version = g->versions().producer();
88
89 flib_def_ = std::make_unique<FunctionLibraryDefinition>(g->flib_def());
90
91 pflr_ = std::unique_ptr<ProcessFunctionLibraryRuntime>(
92 new ProcessFunctionLibraryRuntime(
93 device_mgr_.get(), Env::Default(), nullptr, graph_def_version,
94 flib_def_.get(), OptimizerOptions(), pool_, nullptr, nullptr,
95 Rendezvous::Factory()));
96
97 flr_ = pflr_->GetFLR(device_->name());
98
99 LocalExecutorParams params;
100 params.device = device_;
101 params.function_library = flr_;
102 params.create_kernel = [this, graph_def_version](
103 const std::shared_ptr<const NodeProperties>& props,
104 OpKernel** kernel) {
105 return CreateNonCachedKernel(device_, flr_, props, graph_def_version,
106 kernel);
107 };
108 params.delete_kernel = [](OpKernel* kernel) {
109 DeleteNonCachedKernel(kernel);
110 };
111
112 if (init) {
113 std::unique_ptr<Executor> init_exec;
114 TF_CHECK_OK(NewExecutor(executor_type, params, *init, &init_exec));
115 Executor::Args args;
116 args.rendezvous = rendez_;
117 args.runner = runner;
118 TF_CHECK_OK(init_exec->Run(args));
119 }
120
121 TF_CHECK_OK(NewExecutor(executor_type, params, *g, &exec_));
122 }
123
// Convenience constructor: default SessionOptions, no init graph, a fresh
// local rendezvous, and the default executor type.
Benchmark::Benchmark(const string& device, Graph* g, bool old_benchmark_api)
    : Benchmark(device, g, nullptr, nullptr, nullptr, "", old_benchmark_api) {}
126
Benchmark::~Benchmark() {
  // If `device_` was never set, the constructor did not complete its setup,
  // so there is nothing to release here.
  if (device_) {
    rendez_->Unref();
    // We delete `exec_` before `device_mgr_` because the `exec_` destructor may
    // run kernel destructors that may attempt to access state borrowed from
    // `device_mgr_`, such as the resource manager.
    exec_.reset();
    pflr_.reset();
    device_mgr_.reset();
    // `pool_` is raw-owned by this class; see the constructor.
    delete pool_;
  }
}
139
Run(benchmark::State & state)140 void Benchmark::Run(benchmark::State& state) {
141 RunWithRendezvousArgs({}, {}, state);
142 }
143
GetRendezvousKey(const Node * node)144 string GetRendezvousKey(const Node* node) {
145 string send_device;
146 TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
147 string recv_device;
148 TF_CHECK_OK(GetNodeAttr(node->attrs(), "recv_device", &recv_device));
149 string tensor_name;
150 TF_CHECK_OK(GetNodeAttr(node->attrs(), "tensor_name", &tensor_name));
151 uint64 send_device_incarnation;
152 TF_CHECK_OK(
153 GetNodeAttr(node->attrs(), "send_device_incarnation",
154 reinterpret_cast<int64_t*>(&send_device_incarnation)));
155 return Rendezvous::CreateKey(send_device, send_device_incarnation,
156 recv_device, tensor_name, FrameAndIter(0, 0));
157 }
158
RunWithRendezvousArgs(const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & outputs,benchmark::State & state)159 void Benchmark::RunWithRendezvousArgs(
160 const std::vector<std::pair<string, Tensor>>& inputs,
161 const std::vector<string>& outputs, benchmark::State& state) {
162 if (!device_ || state.max_iterations == 0) {
163 return;
164 }
165 Tensor unused; // In benchmark, we don't care the return value.
166 bool is_dead;
167
168 // Warm up
169 Executor::Args args;
170 args.rendezvous = rendez_;
171 args.runner = [this](std::function<void()> closure) {
172 pool_->Schedule(closure);
173 };
174 static const int kWarmupRuns = 3;
175 for (int i = 0; i < kWarmupRuns; ++i) {
176 for (const auto& p : inputs) {
177 Rendezvous::ParsedKey parsed;
178 TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
179 TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
180 }
181 TF_CHECK_OK(exec_->Run(args));
182 for (const string& key : outputs) {
183 Rendezvous::ParsedKey parsed;
184 TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
185 TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
186 }
187 }
188 TF_CHECK_OK(device_->Sync());
189 VLOG(3) << kWarmupRuns << " warmup runs done.";
190
191 // Benchmark loop. Timer starts automatically at the beginning of the loop
192 // and ends automatically after the last iteration.
193 for (auto s : state) {
194 for (const auto& p : inputs) {
195 Rendezvous::ParsedKey parsed;
196 TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
197 TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
198 }
199 TF_CHECK_OK(exec_->Run(args));
200 for (const string& key : outputs) {
201 Rendezvous::ParsedKey parsed;
202 TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
203 TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
204 }
205 }
206 TF_CHECK_OK(device_->Sync());
207 }
208
209 } // end namespace test
210 } // end namespace tensorflow
211