• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
18 
#include <chrono>
#include <cmath>
#include <ostream>
#include <unordered_map>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
24 
25 namespace tensorflow {
26 class GraphDef;
27 class CostGraphDef;
28 
29 namespace grappler {
30 struct GrapplerItem;
31 
32 constexpr int64 kMemoryUnknown = -1ll;
33 constexpr int64 kZeroMemory = 0ll;
34 
// Throughput characteristics of a device. Any figure that is not supplied is
// INFINITY.
//
// The struct holds only doubles, so the compiler-generated copy/move
// constructors and assignment operators are exactly member-wise (Rule of
// Zero). Not declaring a copy constructor by hand also keeps the implicit
// copy assignment non-deprecated and the type trivially copyable.
struct DeviceInfo {
  // Billions of operations executed per second.
  double gigaops = INFINITY;

  // Bandwidth to main memory in GB per second.
  double gb_per_sec = INFINITY;

  // Read bandwidth to intermediate memory in GB per second.
  double intermediate_read_gb_per_sec = INFINITY;

  // Write bandwidth to intermediate memory in GB per second.
  double intermediate_write_gb_per_sec = INFINITY;

  // All four throughputs are INFINITY.
  DeviceInfo() = default;

  // Sets compute and main-memory throughput; the intermediate-memory
  // bandwidths default to INFINITY when omitted.
  DeviceInfo(double gigaops, double gb_per_sec,
             double intermediate_read_gb_per_sec = INFINITY,
             double intermediate_write_gb_per_sec = INFINITY)
      : gigaops(gigaops),
        gb_per_sec(gb_per_sec),
        intermediate_read_gb_per_sec(intermediate_read_gb_per_sec),
        intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {}
};
68 
69 // Holds the set of things we might want to estimate or measure in Grappler.
70 // Always produce execution time. Other fields are optional depending on the
71 // estimator being used.
72 struct Costs {
73   // Returns a Costs structure with default values for all of the fields.
74   inline Costs();
75 
76   // Builds a Costs structure with all zero values, rather than unknowns.
77   static inline Costs ZeroCosts(bool inaccurate = false);
78 
79   struct MilliSeconds : std::chrono::milliseconds {
MilliSecondsCosts::MilliSeconds80     MilliSeconds() : std::chrono::milliseconds(0) {}
MilliSecondsCosts::MilliSeconds81     MilliSeconds(double d) : std::chrono::milliseconds(static_cast<int64>(d)) {}
MilliSecondsCosts::MilliSeconds82     MilliSeconds(const std::chrono::milliseconds& d)
83         : std::chrono::milliseconds(d) {}
84     MilliSeconds& operator=(const std::chrono::milliseconds& d) {
85       std::chrono::milliseconds::operator=(d);
86       return *this;
87     }
88   };
89   struct MicroSeconds : std::chrono::microseconds {
MicroSecondsCosts::MicroSeconds90     MicroSeconds() : std::chrono::microseconds(0) {}
MicroSecondsCosts::MicroSeconds91     MicroSeconds(double d) : std::chrono::microseconds(static_cast<int64>(d)) {}
MicroSecondsCosts::MicroSeconds92     MicroSeconds(const std::chrono::microseconds& d)
93         : std::chrono::microseconds(d) {}
94     MicroSeconds& operator=(const std::chrono::microseconds& d) {
95       std::chrono::microseconds::operator=(d);
96       return *this;
97     }
asMilliSecondsCosts::MicroSeconds98     MilliSeconds asMilliSeconds() const {
99       return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
100     }
101   };
102   struct NanoSeconds : std::chrono::nanoseconds {
NanoSecondsCosts::NanoSeconds103     NanoSeconds() : std::chrono::nanoseconds(0) {}
NanoSecondsCosts::NanoSeconds104     NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64>(d)) {}
NanoSecondsCosts::NanoSeconds105     NanoSeconds(const std::chrono::nanoseconds& d)
106         : std::chrono::nanoseconds(d) {}
107     NanoSeconds& operator=(const std::chrono::nanoseconds& d) {
108       std::chrono::nanoseconds::operator=(d);
109       return *this;
110     }
asMicroSecondsCosts::NanoSeconds111     MicroSeconds asMicroSeconds() const {
112       return std::chrono::duration_cast<std::chrono::microseconds>(*this);
113     }
asMilliSecondsCosts::NanoSeconds114     MilliSeconds asMilliSeconds() const {
115       return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
116     }
infinityCosts::NanoSeconds117     static NanoSeconds infinity() {
118       return NanoSeconds(std::chrono::nanoseconds::max());
119     }
120   };
121 
122   // We store all our times in nanoseconds. If needs be, we can always switch to
123   // picoseconds in the future by updating this typedef.
124   typedef NanoSeconds Duration;
125 
126   // Overall cost of running the graph; latency.
127   Duration execution_time;
128 
129   // Computation cost of running the graph.
130   Duration compute_time;
131 
132   // Memory access cost of running the graph.
133   Duration memory_time;
134 
135   // Intermediate memory access cost of running the graph
136   Duration intermediate_memory_time;
137   Duration intermediate_memory_read_time;   // Intermediate memory read cost.
138   Duration intermediate_memory_write_time;  // Intermediate memory write cost.
139 
140   // This field can be a very pessimistic estimate of the main memory
141   // requirements of a graph. For example, it might assume that all activations
142   // are live for all of a graph's execution.
143   int64 max_memory;  // Maximum main memory requirement in bytes over all ops.
144   int64 persistent_memory;
145   int64 temporary_memory;
146 
147   // These fields are used for TPU-related estimations. They are per-op
148   // maximums, so each op is evaluated independently, but we want the maximum of
149   // the value over all ops.
150   int64 max_per_op_buffers;    // Sum of all buffers used by the ops.
151   int64 max_per_op_streaming;  // Ignore largest input buffer, assuming it
152                                // streams from main memory.
153 
154   // Number of ops included in this Costs in total.
155   // Default initialized to be one.
156   int64 num_ops_total = 1;
157   // If the time estimation is inaccurate.
158   bool inaccurate = false;
159   // Number of ops that are estimated with unknown shapes.
160   int64 num_ops_with_unknown_shapes = 0;
161   // TODO(pcma): include a counter for total inaccurate ops and counters for
162   // other reasons causing the inaccuracy
163 
164   // Max possible memory usage per device.
165   std::unordered_map<string, uint64> estimated_max_memory_per_device;
166 };
167 
168 inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) {
169   os << d.count() << "ms";
170   return os;
171 }
172 inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) {
173   os << d.count() << "us";
174   return os;
175 }
176 inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) {
177   os << d.count() << "ns";
178   return os;
179 }
180 
Costs()181 Costs::Costs() {
182   execution_time = Duration::zero();
183   compute_time = Duration::zero();
184   memory_time = Duration::zero();
185   intermediate_memory_time = Duration::zero();
186   max_memory = kMemoryUnknown;
187   persistent_memory = kMemoryUnknown;
188   temporary_memory = kMemoryUnknown;
189   max_per_op_buffers = kMemoryUnknown;
190   max_per_op_streaming = kMemoryUnknown;
191 }
192 
ZeroCosts(bool inaccurate)193 Costs Costs::ZeroCosts(bool inaccurate) {
194   Costs costs;
195   costs.execution_time = Duration::zero();
196   costs.compute_time = Duration::zero();
197   costs.memory_time = Duration::zero();
198   costs.intermediate_memory_time = Duration::zero();
199   costs.max_memory = kZeroMemory;
200   costs.persistent_memory = kZeroMemory;
201   costs.temporary_memory = kZeroMemory;
202   costs.max_per_op_buffers = kZeroMemory;
203   costs.max_per_op_streaming = kZeroMemory;
204   costs.inaccurate = inaccurate;
205   return costs;
206 }
207 
208 Costs CombineCosts(const Costs& left, const Costs& right);
209 
210 // Multiplies Costs by a scalar.
211 // Equivalent to applying CombineCosts "multiplier" times.
212 Costs MultiplyCosts(const Costs& costs, int multiplier);
213 
214 // Given a GrapperItem and an optimized implementation of the corresponding
215 // TensorFlow graph, the CostEstimator attempts to predicts the actual cost of
216 // running the graph.
217 class CostEstimator {
218  public:
~CostEstimator()219   virtual ~CostEstimator() {}
220 
221   // Initializes the estimator for the specified grappler item.
222   // The estimator shouldn't be used if this function returns any status other
223   // that OK.
224   virtual Status Initialize(const GrapplerItem& item) = 0;
225 
226   // Predicts the cost of running the given optimized version of the grappler
227   // item.
228   // If a RunMetadata is passed, it will be populated with detailed information
229   // about the cost of running each operation of the optimized graph.
230   // if a double value is passed, it will be set to a value that reflects the
231   // overall cost of running the graph (e.g. the latency of the computation).
232   // Returns a status that indicate is the performance could be estimated or
233   // not.
234   virtual Status PredictCosts(const GraphDef& optimized_graph,
235                               RunMetadata* run_metadata, Costs* cost) const = 0;
236 };
237 
238 }  // end namespace grappler
239 }  // end namespace tensorflow
240 
241 #endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
242