• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
17 
18 #include <limits>
19 
20 #include "tensorflow/core/framework/cost_graph.pb.h"
21 #include "tensorflow/core/framework/step_stats.pb.h"
22 #include "tensorflow/core/grappler/clusters/cluster.h"
23 #include "tensorflow/core/grappler/costs/robust_stats.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/kernels/ops_util.h"
26 #include "tensorflow/core/lib/core/blocking_counter.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/public/session.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 
MeasuringCostEstimator(Cluster * cluster,int measurement_steps,int measurement_threads)33 MeasuringCostEstimator::MeasuringCostEstimator(Cluster* cluster,
34                                                int measurement_steps,
35                                                int measurement_threads)
36     : measurement_steps_(measurement_steps),
37       measurement_threads_(measurement_threads) {
38   CHECK_GE(measurement_steps, 1);
39   if (measurement_threads > 0) {
40     thread_pool_.reset(new thread::ThreadPool(
41         Env::Default(), SanitizeThreadSuffix("measurements"),
42         measurement_threads));
43   }
44   cluster_ = cluster;
45 }
46 
Initialize(const GrapplerItem & item)47 Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) {
48   feed_ = item.feed;
49   fetch_ = item.fetch;
50   return cluster_->Initialize(item);
51 }
52 
PredictCosts(const GraphDef & optimized_graph,RunMetadata * run_metadata,Costs * costs) const53 Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
54                                             RunMetadata* run_metadata,
55                                             Costs* costs) const {
56   CostGraphDef* cost_graph = nullptr;
57   if (run_metadata) {
58     cost_graph = run_metadata->mutable_cost_graph();
59   }
60   const bool running_simulation = (cluster_->type() == "virtual");
61 
62   std::vector<double> times(measurement_steps_);
63   BlockingCounter barrier(measurement_steps_);
64 
65   mutex status_mu;
66   Status status;
67 
68   auto measurement_fn = [&](const int step) {
69     const Costs::MicroSeconds start = Env::Default()->NowMicros();
70 
71     RunMetadata metadata;
72     const Status local_status =
73         cluster_->Run(optimized_graph, feed_, fetch_, &metadata);
74     {
75       mutex_lock lock(status_mu);
76       status.Update(local_status);
77     }
78     if (step < 0) {
79       // Discard the first iteration as it triggers the warmup, and therefore
80       // takes much longer than a normal step.
81       return;
82     }
83     if (!local_status.ok()) {
84       // Discard the data if the run wasn't successful.
85       barrier.DecrementCount();
86       return;
87     }
88 
89     const Costs::MicroSeconds finish = Env::Default()->NowMicros();
90     if (running_simulation) {
91       // When running simulation, return the estimated runtime, not the time it
92       // takes to run the simulation.
93       double time = 0.0;
94       for (const DeviceStepStats& stepstats :
95            metadata.step_stats().dev_stats()) {
96         for (const NodeExecStats& node_stats : stepstats.node_stats()) {
97           const double completion_time =
98               node_stats.all_end_rel_micros() + node_stats.all_start_micros();
99           time = std::max(time, completion_time * 1e3);
100         }
101       }
102       times[step] = time;
103     } else {
104       const double time = (finish - start).count() * 1e3;
105       times[step] = time;
106     }
107     if (cost_graph && (step + 1 == measurement_steps_)) {
108       metadata.mutable_cost_graph()->Swap(cost_graph);
109     }
110 
111     barrier.DecrementCount();
112   };
113 
114   // Initialize the computation and warm up TensorFlow.
115   measurement_fn(-1);
116 
117   if (!status.ok()) {
118     LOG(ERROR) << "Failed to run start measurements: "
119                << status.error_message();
120     costs->execution_time = Costs::Duration::max();
121     return status;
122   }
123 
124   // Run "measurement_steps_" and measure the time.
125   VLOG(1) << "Number of measurement steps: " << measurement_steps_;
126   if (measurement_threads_ > 0) {
127     for (int i = 0; i < measurement_steps_; ++i) {
128       thread_pool_->Schedule([i, &measurement_fn]() { measurement_fn(i); });
129     }
130     barrier.Wait();
131   } else {
132     for (int i = 0; i < measurement_steps_ && status.ok(); ++i) {
133       measurement_fn(i);
134     }
135   }
136 
137   if (!status.ok()) {
138     LOG(ERROR) << "Failed to measure graph performance: "
139                << status.error_message();
140     costs->execution_time = Costs::Duration::max();
141     return status;
142   }
143 
144   // Compute the average time of the measure steps. Use Huber statistics
145   // to filter out outliers.
146   RobustStats stats(times);
147   costs->execution_time = Costs::Duration(stats.mean());
148 
149   return Status::OK();
150 }
151 }  // end namespace grappler
152 }  // end namespace tensorflow
153