1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 18 19 #include "absl/types/span.h" 20 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 21 #include "tensorflow/compiler/xla/service/hlo_computation.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 24 #include "tensorflow/compiler/xla/shape_util.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace xla { 31 32 // HloCostAnalysis traverses an HLO graph and calculates the amount of 33 // computations required for the graph. Each HLO instruction handler provides 34 // the computation cost of the instruction, and the values are accumulated 35 // during the traversal for the entire graph. We treat normal floating point 36 // operations separately from transcendental operations. 37 class HloCostAnalysis : public ConstDfsHloVisitor { 38 public: 39 // Each HLO is associated to a vector of properties with the indices given 40 // below. Sub-classes can add further properties. 41 // MSVC 14.0 limitation requires the consts. 42 typedef std::map<string, float> Properties; 43 static constexpr const char kFlopsKey[] = "flops"; 44 static constexpr const char kTranscendentalsKey[] = "transcendentals"; 45 static constexpr const char kBytesAccessedKey[] = "bytes accessed"; 46 static constexpr const char kOptimalSecondsKey[] = "optimal_seconds"; 47 48 // shape_size is a function which returns the size in bytes of the top-level 49 // buffer of a shape. 50 using ShapeSizeFunction = std::function<int64(const Shape&)>; 51 explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); 52 53 Status HandleElementwiseUnary(const HloInstruction* hlo) override; 54 Status HandleElementwiseBinary(const HloInstruction* hlo) override; 55 Status HandleConstant(const HloInstruction* constant) override; 56 Status HandleIota(const HloInstruction* iota) override; 57 Status HandleGetTupleElement( 58 const HloInstruction* get_tuple_element) override; 59 Status HandleSelect(const HloInstruction* hlo) override; 60 Status HandleTupleSelect(const HloInstruction* hlo) override; 61 Status HandleCompare(const HloInstruction* compare) override; 62 Status HandleClamp(const HloInstruction* clamp) override; 63 Status HandleReducePrecision(const HloInstruction* hlo) override; 64 Status HandleConcatenate(const HloInstruction* concatenate) override; 65 Status HandleCopyStart(const HloInstruction* send) override; 66 Status HandleCopyDone(const HloInstruction* send_done) override; 67 Status HandleSend(const HloInstruction* send) override; 68 Status HandleSendDone(const HloInstruction* send_done) override; 69 Status HandleRecv(const HloInstruction* recv) override; 70 Status HandleRecvDone(const HloInstruction* recv_done) override; 71 Status HandleConvert(const HloInstruction* convert) override; 72 Status HandleCopy(const HloInstruction* copy) override; 73 Status HandleDomain(const HloInstruction* domain) override; 74 Status HandleDot(const HloInstruction* dot) override; 75 Status HandleConvolution(const HloInstruction* convolution) override; 76 Status HandleFft(const HloInstruction* fft) override; 77 Status HandleTriangularSolve(const HloInstruction* hlo) override; 78 Status HandleCholesky(const HloInstruction* hlo) override; 79 Status HandleAllReduce(const HloInstruction* crs) override; 80 Status HandleAllToAll(const HloInstruction* hlo) override; 81 Status HandleCollectivePermute(const HloInstruction* hlo) override; 82 Status HandleReplicaId(const HloInstruction* hlo) override; 83 Status HandlePartitionId(const HloInstruction* hlo) override; 84 Status HandleInfeed(const HloInstruction* infeed) override; 85 Status HandleOutfeed(const HloInstruction* outfeed) override; 86 Status HandleRng(const HloInstruction* random) override; 87 Status HandleRngGetAndUpdateState(const HloInstruction* random) override; 88 Status HandleReverse(const HloInstruction* reverse) override; 89 Status HandleSort(const HloInstruction* sort) override; 90 Status HandleParameter(const HloInstruction* parameter) override; 91 Status HandleReduce(const HloInstruction* reduce) override; 92 Status HandleBatchNormTraining( 93 const HloInstruction* batch_norm_training) override; 94 Status HandleBatchNormInference( 95 const HloInstruction* batch_norm_inference) override; 96 Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; 97 Status HandleFusion(const HloInstruction* fusion) override; 98 Status HandleCall(const HloInstruction* call) override; 99 Status HandleCustomCall(const HloInstruction* custom_call) override; 100 Status HandleSlice(const HloInstruction* slice) override; 101 Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; 102 Status HandleDynamicUpdateSlice( 103 const HloInstruction* dynamic_update_slice) override; 104 Status HandleTuple(const HloInstruction* tuple) override; 105 Status HandleMap(const HloInstruction* map) override; 106 Status HandleReduceWindow(const HloInstruction* reduce_window) override; 107 Status HandleSelectAndScatter(const HloInstruction* instruction) override; 108 Status HandleBitcast(const HloInstruction* bitcast) override; 109 Status HandleBroadcast(const HloInstruction* broadcast) override; 110 Status HandlePad(const HloInstruction* pad) override; 111 Status HandleReshape(const HloInstruction* reshape) override; 112 Status HandleAddDependency(const HloInstruction* add_dependency) override; 113 Status HandleAfterAll(const HloInstruction* token) override; 114 Status HandleTranspose(const HloInstruction* transpose) override; 115 Status HandleWhile(const HloInstruction* xla_while) override; 116 Status HandleConditional(const HloInstruction* conditional) override; 117 Status HandleGather(const HloInstruction* gather) override; 118 Status HandleScatter(const HloInstruction* scatter) override; 119 Status HandleGetDimensionSize(const HloInstruction* get_size) override; 120 Status HandleSetDimensionSize(const HloInstruction* set_size) override; 121 Status FinishVisit(const HloInstruction* root) override; 122 123 Status Preprocess(const HloInstruction* hlo) override; 124 Status Postprocess(const HloInstruction* hlo) override; 125 126 // Decorates shape_size_ by returning 0 immediately if the shape does not have 127 // a layout. 128 int64 GetShapeSize(const Shape& shape) const; 129 130 // Set the rates used to calculate the time taken by the computation. These 131 // need to be set before visiting starts. set_flops_per_second(float value)132 void set_flops_per_second(float value) { 133 per_second_rates_[kFlopsKey] = value; 134 } set_transcendentals_per_second(float value)135 void set_transcendentals_per_second(float value) { 136 per_second_rates_[kTranscendentalsKey] = value; 137 } set_bytes_per_second(float value)138 void set_bytes_per_second(float value) { 139 per_second_rates_[kBytesAccessedKey] = value; 140 } 141 142 // Returns properties for the computation. 143 float flop_count() const; 144 float transcendental_count() const; 145 float bytes_accessed() const; 146 float optimal_seconds() const; 147 148 // Returns the respective cost computed for a particular HLO instruction, or 0 149 // if the HLO was not found to have a cost in the analysis. 150 // 151 // Note that the cost for sub HLO instructions are also returned if asked. For 152 // example, body and condition of a while, fused instructions within a 153 // fusion, or the add instruction of a reduce. 154 int64 flop_count(const HloInstruction& hlo) const; 155 int64 transcendental_count(const HloInstruction& hlo) const; 156 int64 bytes_accessed(const HloInstruction& hlo) const; 157 int64 operand_bytes_accessed(const HloInstruction& hlo, int64 operand_num, 158 ShapeIndex index = {}) const; 159 int64 output_bytes_accessed(const HloInstruction& hlo, 160 ShapeIndex index = {}) const; 161 float optimal_seconds(const HloInstruction& hlo) const; 162 properties()163 const Properties& properties() const { return properties_sum_; } property(const string & key)164 const float property(const string& key) const { 165 return GetProperty(key, properties()); 166 } 167 168 // Returns the specified per-second rate used by cost analysis. per_second_rate(const string & key)169 const float per_second_rate(const string& key) const { 170 return GetProperty(key, per_second_rates_); 171 } 172 173 protected: 174 typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties; 175 176 // An FMA counts as two floating point operations in these analyzes. 177 static constexpr int64 kFmaFlops = 2; 178 179 HloCostAnalysis(const ShapeSizeFunction& shape_size, 180 const Properties& per_second_rates); 181 182 virtual std::unique_ptr<HloCostAnalysis> CreateNestedCostAnalysis( 183 const ShapeSizeFunction& shape_size, const Properties& per_second_rates); 184 185 // Returns the properties computed from visiting the computation rooted at the 186 // given hlo. The cost of visited sub HLO instructions is saved to 187 // hlo_properties_, which will be used by functions such as 188 // flop_count(hlo_instruction) to return cost of a particular HLO instruction. 189 StatusOr<Properties> ProcessSubcomputation(HloComputation* computation); 190 191 // Utility function to handle all element-wise operations. 192 Status HandleElementwiseOp(const HloInstruction* hlo_instruction); 193 194 // Returns the default value if the key is not present in the 195 // properties. Otherwise, returns the value that the key maps to from the 196 // properties parameter. 197 static float GetProperty(const string& key, const Properties& properties, 198 float default_value = 0.0f); 199 200 // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key 201 // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that 202 // the key maps to in the properties of the given hlo. 203 static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, 204 const HloToProperties& hlo_to_properties); 205 206 // Traverses a fusion operand to find the actual bytes accessed by the fusion 207 // node. 208 int64 FusionParameterReadBytes(const HloInstruction* hlo) const; 209 210 // Set bytes accessed by the specified operand and shape index. 211 void SetOperandBytesAccessed(int64 operand_num, float value); 212 void SetOperandBytesAccessed(int64 operand_num, ShapeIndex index, 213 float value); 214 215 // Set bytes accessed by the output at the shape index. 216 void SetOutputBytesAccessed(float value); 217 void SetOutputBytesAccessed(ShapeIndex index, float value); 218 219 // Return the key that is used to index into Properties for the specified 220 // input/output at the shape index. 221 static std::string GetOperandBytesAccessedKey(int64 operand_num, 222 ShapeIndex index = {}); 223 static std::string GetOutputBytesAccessedKey(ShapeIndex index = {}); 224 225 // Function which computes the size of the top-level of a given shape (not 226 // including nested elements, if any). If null then bytes_accessed methods 227 // return an error. 228 const ShapeSizeFunction shape_size_; 229 230 HloToProperties hlo_properties_; 231 232 // If true, the time taken will be computed from the rates for each property 233 // and the total time will be the maximum time, which is the time of the 234 // bottleneck. 235 bool current_should_compute_bottleneck_time_; 236 237 // The properties of the currently visited instruction. A HandleFoo method can 238 // modify these to change the default values computed in Preprocess. 239 Properties current_properties_; 240 241 // The sum of the properties of all HLOs in the computation. 242 Properties properties_sum_; 243 244 // How much of each property can be processed per second. E.g. if the property 245 // is bytes accessed, this is the number of bytes that can be processed per 246 // second. Is empty if no rates have been set. 247 Properties per_second_rates_; 248 249 TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis); 250 }; 251 252 } // namespace xla 253 254 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 255