1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
17
18 #include <limits>
19
20 #include "tensorflow/core/framework/cost_graph.pb.h"
21 #include "tensorflow/core/framework/step_stats.pb.h"
22 #include "tensorflow/core/grappler/clusters/cluster.h"
23 #include "tensorflow/core/grappler/costs/robust_stats.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/kernels/ops_util.h"
26 #include "tensorflow/core/lib/core/blocking_counter.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/public/session.h"
29
30 namespace tensorflow {
31 namespace grappler {
32
MeasuringCostEstimator(Cluster * cluster,int measurement_steps,int measurement_threads)33 MeasuringCostEstimator::MeasuringCostEstimator(Cluster* cluster,
34 int measurement_steps,
35 int measurement_threads)
36 : measurement_steps_(measurement_steps),
37 measurement_threads_(measurement_threads) {
38 CHECK_GE(measurement_steps, 1);
39 if (measurement_threads > 0) {
40 thread_pool_.reset(new thread::ThreadPool(
41 Env::Default(), SanitizeThreadSuffix("measurements"),
42 measurement_threads));
43 }
44 cluster_ = cluster;
45 }
46
Initialize(const GrapplerItem & item)47 Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) {
48 feed_ = item.feed;
49 fetch_ = item.fetch;
50 return cluster_->Initialize(item);
51 }
52
PredictCosts(const GraphDef & optimized_graph,RunMetadata * run_metadata,Costs * costs) const53 Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
54 RunMetadata* run_metadata,
55 Costs* costs) const {
56 CostGraphDef* cost_graph = nullptr;
57 if (run_metadata) {
58 cost_graph = run_metadata->mutable_cost_graph();
59 }
60 const bool running_simulation = (cluster_->type() == "virtual");
61
62 std::vector<double> times(measurement_steps_);
63 BlockingCounter barrier(measurement_steps_);
64
65 mutex status_mu;
66 Status status;
67
68 auto measurement_fn = [&](const int step) {
69 const Costs::MicroSeconds start = Env::Default()->NowMicros();
70
71 RunMetadata metadata;
72 const Status local_status =
73 cluster_->Run(optimized_graph, feed_, fetch_, &metadata);
74 {
75 mutex_lock lock(status_mu);
76 status.Update(local_status);
77 }
78 if (step < 0) {
79 // Discard the first iteration as it triggers the warmup, and therefore
80 // takes much longer than a normal step.
81 return;
82 }
83 if (!local_status.ok()) {
84 // Discard the data if the run wasn't successful.
85 barrier.DecrementCount();
86 return;
87 }
88
89 const Costs::MicroSeconds finish = Env::Default()->NowMicros();
90 if (running_simulation) {
91 // When running simulation, return the estimated runtime, not the time it
92 // takes to run the simulation.
93 double time = 0.0;
94 for (const DeviceStepStats& stepstats :
95 metadata.step_stats().dev_stats()) {
96 for (const NodeExecStats& node_stats : stepstats.node_stats()) {
97 const double completion_time =
98 node_stats.all_end_rel_micros() + node_stats.all_start_micros();
99 time = std::max(time, completion_time * 1e3);
100 }
101 }
102 times[step] = time;
103 } else {
104 const double time = (finish - start).count() * 1e3;
105 times[step] = time;
106 }
107 if (cost_graph && (step + 1 == measurement_steps_)) {
108 metadata.mutable_cost_graph()->Swap(cost_graph);
109 }
110
111 barrier.DecrementCount();
112 };
113
114 // Initialize the computation and warm up TensorFlow.
115 measurement_fn(-1);
116
117 if (!status.ok()) {
118 LOG(ERROR) << "Failed to run start measurements: "
119 << status.error_message();
120 costs->execution_time = Costs::Duration::max();
121 return status;
122 }
123
124 // Run "measurement_steps_" and measure the time.
125 VLOG(1) << "Number of measurement steps: " << measurement_steps_;
126 if (measurement_threads_ > 0) {
127 for (int i = 0; i < measurement_steps_; ++i) {
128 thread_pool_->Schedule([i, &measurement_fn]() { measurement_fn(i); });
129 }
130 barrier.Wait();
131 } else {
132 for (int i = 0; i < measurement_steps_ && status.ok(); ++i) {
133 measurement_fn(i);
134 }
135 }
136
137 if (!status.ok()) {
138 LOG(ERROR) << "Failed to measure graph performance: "
139 << status.error_message();
140 costs->execution_time = Costs::Duration::max();
141 return status;
142 }
143
144 // Compute the average time of the measure steps. Use Huber statistics
145 // to filter out outliers.
146 RobustStats stats(times);
147 costs->execution_time = Costs::Duration(stats.mean());
148
149 return Status::OK();
150 }
151 } // end namespace grappler
152 } // end namespace tensorflow
153