• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
18 
19 #include "absl/types/span.h"
20 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/platform/macros.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace xla {
31 
32 // HloCostAnalysis traverses an HLO graph and calculates the amount of
33 // computations required for the graph. Each HLO instruction handler provides
34 // the computation cost of the instruction, and the values are accumulated
35 // during the traversal for the entire graph. We treat normal floating point
36 // operations separately from transcendental operations.
37 class HloCostAnalysis : public ConstDfsHloVisitor {
38  public:
39   // Each HLO is associated to a vector of properties with the indices given
40   // below. Sub-classes can add further properties.
41   // MSVC 14.0 limitation requires the consts.
42   typedef std::map<string, float> Properties;
43   static constexpr const char kFlopsKey[] = "flops";
44   static constexpr const char kTranscendentalsKey[] = "transcendentals";
45   static constexpr const char kBytesAccessedKey[] = "bytes accessed";
46   static constexpr const char kOptimalSecondsKey[] = "optimal_seconds";
47 
48   // shape_size is a function which returns the size in bytes of the top-level
49   // buffer of a shape.
50   using ShapeSizeFunction = std::function<int64(const Shape&)>;
51   explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
52 
53   Status HandleElementwiseUnary(const HloInstruction* hlo) override;
54   Status HandleElementwiseBinary(const HloInstruction* hlo) override;
55   Status HandleConstant(const HloInstruction* constant) override;
56   Status HandleIota(const HloInstruction* iota) override;
57   Status HandleGetTupleElement(
58       const HloInstruction* get_tuple_element) override;
59   Status HandleSelect(const HloInstruction* hlo) override;
60   Status HandleTupleSelect(const HloInstruction* hlo) override;
61   Status HandleCompare(const HloInstruction* compare) override;
62   Status HandleClamp(const HloInstruction* clamp) override;
63   Status HandleReducePrecision(const HloInstruction* hlo) override;
64   Status HandleConcatenate(const HloInstruction* concatenate) override;
65   Status HandleCopyStart(const HloInstruction* send) override;
66   Status HandleCopyDone(const HloInstruction* send_done) override;
67   Status HandleSend(const HloInstruction* send) override;
68   Status HandleSendDone(const HloInstruction* send_done) override;
69   Status HandleRecv(const HloInstruction* recv) override;
70   Status HandleRecvDone(const HloInstruction* recv_done) override;
71   Status HandleConvert(const HloInstruction* convert) override;
72   Status HandleCopy(const HloInstruction* copy) override;
73   Status HandleDomain(const HloInstruction* domain) override;
74   Status HandleDot(const HloInstruction* dot) override;
75   Status HandleConvolution(const HloInstruction* convolution) override;
76   Status HandleFft(const HloInstruction* fft) override;
77   Status HandleTriangularSolve(const HloInstruction* hlo) override;
78   Status HandleCholesky(const HloInstruction* hlo) override;
79   Status HandleAllGather(const HloInstruction* hlo) override;
80   Status HandleAllGatherStart(const HloInstruction* hlo) override;
81   Status HandleAllGatherDone(const HloInstruction* hlo) override;
82   Status HandleAllReduce(const HloInstruction* crs) override;
83   Status HandleReduceScatter(const HloInstruction* hlo) override;
84   Status HandleAllReduceStart(const HloInstruction* hlo) override;
85   Status HandleAllReduceDone(const HloInstruction* hlo) override;
86   Status HandleAllToAll(const HloInstruction* hlo) override;
87   Status HandleCollectivePermute(const HloInstruction* hlo) override;
88   Status HandleCollectivePermuteStart(const HloInstruction* hlo) override;
89   Status HandleCollectivePermuteDone(const HloInstruction* hlo) override;
90   Status HandleReplicaId(const HloInstruction* hlo) override;
91   Status HandlePartitionId(const HloInstruction* hlo) override;
92   Status HandleInfeed(const HloInstruction* infeed) override;
93   Status HandleOutfeed(const HloInstruction* outfeed) override;
94   Status HandleRng(const HloInstruction* random) override;
95   Status HandleRngBitGenerator(const HloInstruction* random) override;
96   Status HandleRngGetAndUpdateState(const HloInstruction* random) override;
97   Status HandleReverse(const HloInstruction* reverse) override;
98   Status HandleSort(const HloInstruction* sort) override;
99   Status HandleParameter(const HloInstruction* parameter) override;
100   Status HandleReduce(const HloInstruction* reduce) override;
101   Status HandleBatchNormTraining(
102       const HloInstruction* batch_norm_training) override;
103   Status HandleBatchNormInference(
104       const HloInstruction* batch_norm_inference) override;
105   Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override;
106   Status HandleFusion(const HloInstruction* fusion) override;
107   Status HandleCall(const HloInstruction* call) override;
108   Status HandleCustomCall(const HloInstruction* custom_call) override;
109   Status HandleSlice(const HloInstruction* slice) override;
110   Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override;
111   Status HandleDynamicUpdateSlice(
112       const HloInstruction* dynamic_update_slice) override;
113   Status HandleTuple(const HloInstruction* tuple) override;
114   Status HandleMap(const HloInstruction* map) override;
115   Status HandleReduceWindow(const HloInstruction* reduce_window) override;
116   Status HandleSelectAndScatter(const HloInstruction* instruction) override;
117   Status HandleBitcast(const HloInstruction* bitcast) override;
118   Status HandleBroadcast(const HloInstruction* broadcast) override;
119   Status HandlePad(const HloInstruction* pad) override;
120   Status HandleReshape(const HloInstruction* reshape) override;
121   Status HandleDynamicReshape(const HloInstruction* reshape) override;
122   Status HandleAddDependency(const HloInstruction* add_dependency) override;
123   Status HandleAfterAll(const HloInstruction* token) override;
124   Status HandleTranspose(const HloInstruction* transpose) override;
125   Status HandleWhile(const HloInstruction* xla_while) override;
126   Status HandleConditional(const HloInstruction* conditional) override;
127   Status HandleGather(const HloInstruction* gather) override;
128   Status HandleScatter(const HloInstruction* scatter) override;
129   Status HandleGetDimensionSize(const HloInstruction* get_size) override;
130   Status HandleSetDimensionSize(const HloInstruction* set_size) override;
131   Status FinishVisit(const HloInstruction* root) override;
132 
133   Status Preprocess(const HloInstruction* hlo) override;
134   Status Postprocess(const HloInstruction* hlo) override;
135 
136   // Decorates shape_size_ by returning 0 immediately if the shape does not have
137   // a layout.
138   int64 GetShapeSize(const Shape& shape) const;
139 
140   // Set the rates used to calculate the time taken by the computation. These
141   // need to be set before visiting starts.
set_flops_per_second(float value)142   void set_flops_per_second(float value) {
143     per_second_rates_[kFlopsKey] = value;
144   }
set_transcendentals_per_second(float value)145   void set_transcendentals_per_second(float value) {
146     per_second_rates_[kTranscendentalsKey] = value;
147   }
set_bytes_per_second(float value)148   void set_bytes_per_second(float value) {
149     per_second_rates_[kBytesAccessedKey] = value;
150   }
151 
152   // Returns properties for the computation.
153   float flop_count() const;
154   float transcendental_count() const;
155   float bytes_accessed() const;
156   float optimal_seconds() const;
157 
158   // Returns the respective cost computed for a particular HLO instruction, or 0
159   // if the HLO was not found to have a cost in the analysis.
160   //
161   // Note that the cost for sub HLO instructions are also returned if asked. For
162   // example, body and condition of a while, fused instructions within a
163   // fusion, or the add instruction of a reduce.
164   int64 flop_count(const HloInstruction& hlo) const;
165   int64 transcendental_count(const HloInstruction& hlo) const;
166   int64 bytes_accessed(const HloInstruction& hlo) const;
167   int64 operand_bytes_accessed(const HloInstruction& hlo, int64_t operand_num,
168                                ShapeIndex index = {}) const;
169   int64 output_bytes_accessed(const HloInstruction& hlo,
170                               ShapeIndex index = {}) const;
171   float optimal_seconds(const HloInstruction& hlo) const;
172 
173   // Get bytes read/written by this HLO. If memory_space is provided, it returns
174   // the bytes read/written from/to the given memory space only.
175   int64 GetBytesRead(const HloInstruction& hlo,
176                      absl::optional<int64> memory_space = absl::nullopt) const;
177   int64 GetBytesWritten(
178       const HloInstruction& hlo,
179       absl::optional<int64> memory_space = absl::nullopt) const;
180 
properties()181   const Properties& properties() const { return properties_sum_; }
property(const string & key)182   const float property(const string& key) const {
183     return GetProperty(key, properties());
184   }
185 
186   // Returns the specified per-second rate used by cost analysis.
per_second_rate(const string & key)187   const float per_second_rate(const string& key) const {
188     return GetProperty(key, per_second_rates_);
189   }
190 
191   // Return the key that is used to index into Properties for the specified
192   // input/output at the shape index.
193   static std::string GetOperandBytesAccessedKey(int64_t operand_num,
194                                                 ShapeIndex index = {});
195   static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});
196 
197  protected:
198   typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;
199 
200   // An FMA counts as two floating point operations in these analyzes.
201   static constexpr int64_t kFmaFlops = 2;
202 
203   HloCostAnalysis(const ShapeSizeFunction& shape_size,
204                   const Properties& per_second_rates);
205 
206   virtual std::unique_ptr<HloCostAnalysis> CreateNestedCostAnalysis(
207       const ShapeSizeFunction& shape_size, const Properties& per_second_rates);
208 
209   // Returns the properties computed from visiting the computation rooted at the
210   // given hlo. The cost of visited sub HLO instructions is saved to
211   // hlo_properties_, which will be used by functions such as
212   // flop_count(hlo_instruction) to return cost of a particular HLO instruction.
213   StatusOr<Properties> ProcessSubcomputation(HloComputation* computation);
214 
215   // Utility function to handle all element-wise operations.
216   Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
217 
218   // Returns the default value if the key is not present in the
219   // properties. Otherwise, returns the value that the key maps to from the
220   // properties parameter.
221   static float GetProperty(const string& key, const Properties& properties,
222                            float default_value = 0.0f);
223 
224   // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key
225   // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that
226   // the key maps to in the properties of the given hlo.
227   static float GetPropertyForHlo(const HloInstruction& hlo, const string& key,
228                                  const HloToProperties& hlo_to_properties);
229 
230   // Traverses a fusion operand to find the actual bytes accessed by the fusion
231   // node.
232   int64 FusionParameterReadBytes(const HloInstruction* hlo) const;
233 
234   // Set bytes accessed by the specified operand and shape index.
235   void SetOperandBytesAccessed(int64_t operand_num, float value);
236   void SetOperandBytesAccessed(int64_t operand_num, ShapeIndex index,
237                                float value);
238 
239   // Set bytes accessed by the output at the shape index.
240   void SetOutputBytesAccessed(float value);
241   void SetOutputBytesAccessed(ShapeIndex index, float value);
242 
243   // Function which computes the size of the top-level of a given shape (not
244   // including nested elements, if any). If null then bytes_accessed methods
245   // return an error.
246   const ShapeSizeFunction shape_size_;
247 
248   HloToProperties hlo_properties_;
249 
250   // If true, the time taken will be computed from the rates for each property
251   // and the total time will be the maximum time, which is the time of the
252   // bottleneck.
253   bool current_should_compute_bottleneck_time_;
254 
255   // The properties of the currently visited instruction. A HandleFoo method can
256   // modify these to change the default values computed in Preprocess.
257   Properties current_properties_;
258 
259   // The sum of the properties of all HLOs in the computation.
260   Properties properties_sum_;
261 
262   // How much of each property can be processed per second. E.g. if the property
263   // is bytes accessed, this is the number of bytes that can be processed per
264   // second. Is empty if no rates have been set.
265   Properties per_second_rates_;
266 
267   TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis);
268 };
269 
270 }  // namespace xla
271 
272 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
273