1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 18 19 #include "absl/types/span.h" 20 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 21 #include "tensorflow/compiler/xla/service/hlo_computation.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 24 #include "tensorflow/compiler/xla/shape_util.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace xla { 31 32 // HloCostAnalysis traverses an HLO graph and calculates the amount of 33 // computations required for the graph. Each HLO instruction handler provides 34 // the computation cost of the instruction, and the values are accumulated 35 // during the traversal for the entire graph. We treat normal floating point 36 // operations separately from transcendental operations. 37 class HloCostAnalysis : public ConstDfsHloVisitor { 38 public: 39 // Each HLO is associated to a vector of properties with the indices given 40 // below. Sub-classes can add further properties. 41 // MSVC 14.0 limitation requires the consts. 42 typedef std::map<string, float> Properties; 43 static constexpr const char kFlopsKey[] = "flops"; 44 static constexpr const char kTranscendentalsKey[] = "transcendentals"; 45 static constexpr const char kBytesAccessedKey[] = "bytes accessed"; 46 static constexpr const char kOptimalSecondsKey[] = "optimal_seconds"; 47 48 // shape_size is a function which returns the size in bytes of the top-level 49 // buffer of a shape. 50 using ShapeSizeFunction = std::function<int64(const Shape&)>; 51 explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); 52 53 Status HandleElementwiseUnary(const HloInstruction* hlo) override; 54 Status HandleElementwiseBinary(const HloInstruction* hlo) override; 55 Status HandleConstant(const HloInstruction* constant) override; 56 Status HandleIota(const HloInstruction* iota) override; 57 Status HandleGetTupleElement( 58 const HloInstruction* get_tuple_element) override; 59 Status HandleSelect(const HloInstruction* hlo) override; 60 Status HandleTupleSelect(const HloInstruction* hlo) override; 61 Status HandleCompare(const HloInstruction* compare) override; 62 Status HandleClamp(const HloInstruction* clamp) override; 63 Status HandleReducePrecision(const HloInstruction* hlo) override; 64 Status HandleConcatenate(const HloInstruction* concatenate) override; 65 Status HandleCopyStart(const HloInstruction* send) override; 66 Status HandleCopyDone(const HloInstruction* send_done) override; 67 Status HandleSend(const HloInstruction* send) override; 68 Status HandleSendDone(const HloInstruction* send_done) override; 69 Status HandleRecv(const HloInstruction* recv) override; 70 Status HandleRecvDone(const HloInstruction* recv_done) override; 71 Status HandleConvert(const HloInstruction* convert) override; 72 Status HandleCopy(const HloInstruction* copy) override; 73 Status HandleDomain(const HloInstruction* domain) override; 74 Status HandleDot(const HloInstruction* dot) override; 75 Status HandleConvolution(const HloInstruction* convolution) override; 76 Status HandleFft(const HloInstruction* fft) override; 77 Status HandleTriangularSolve(const HloInstruction* hlo) override; 78 Status HandleCholesky(const HloInstruction* hlo) override; 79 Status HandleAllGather(const HloInstruction* hlo) override; 80 Status HandleAllReduce(const HloInstruction* crs) override; 81 Status HandleAllToAll(const HloInstruction* hlo) override; 82 Status HandleCollectivePermute(const HloInstruction* hlo) override; 83 Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; 84 Status HandleCollectivePermuteDone(const HloInstruction* hlo) override; 85 Status HandleReplicaId(const HloInstruction* hlo) override; 86 Status HandlePartitionId(const HloInstruction* hlo) override; 87 Status HandleInfeed(const HloInstruction* infeed) override; 88 Status HandleOutfeed(const HloInstruction* outfeed) override; 89 Status HandleRng(const HloInstruction* random) override; 90 Status HandleRngBitGenerator(const HloInstruction* random) override; 91 Status HandleRngGetAndUpdateState(const HloInstruction* random) override; 92 Status HandleReverse(const HloInstruction* reverse) override; 93 Status HandleSort(const HloInstruction* sort) override; 94 Status HandleParameter(const HloInstruction* parameter) override; 95 Status HandleReduce(const HloInstruction* reduce) override; 96 Status HandleBatchNormTraining( 97 const HloInstruction* batch_norm_training) override; 98 Status HandleBatchNormInference( 99 const HloInstruction* batch_norm_inference) override; 100 Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; 101 Status HandleFusion(const HloInstruction* fusion) override; 102 Status HandleCall(const HloInstruction* call) override; 103 Status HandleCustomCall(const HloInstruction* custom_call) override; 104 Status HandleSlice(const HloInstruction* slice) override; 105 Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; 106 Status HandleDynamicUpdateSlice( 107 const HloInstruction* dynamic_update_slice) override; 108 Status HandleTuple(const HloInstruction* tuple) override; 109 Status HandleMap(const HloInstruction* map) override; 110 Status HandleReduceWindow(const HloInstruction* reduce_window) override; 111 Status HandleSelectAndScatter(const HloInstruction* instruction) override; 112 Status HandleBitcast(const HloInstruction* bitcast) override; 113 Status HandleBroadcast(const HloInstruction* broadcast) override; 114 Status HandlePad(const HloInstruction* pad) override; 115 Status HandleReshape(const HloInstruction* reshape) override; 116 Status HandleDynamicReshape(const HloInstruction* reshape) override; 117 Status HandleAddDependency(const HloInstruction* add_dependency) override; 118 Status HandleAfterAll(const HloInstruction* token) override; 119 Status HandleTranspose(const HloInstruction* transpose) override; 120 Status HandleWhile(const HloInstruction* xla_while) override; 121 Status HandleConditional(const HloInstruction* conditional) override; 122 Status HandleGather(const HloInstruction* gather) override; 123 Status HandleScatter(const HloInstruction* scatter) override; 124 Status HandleGetDimensionSize(const HloInstruction* get_size) override; 125 Status HandleSetDimensionSize(const HloInstruction* set_size) override; 126 Status FinishVisit(const HloInstruction* root) override; 127 128 Status Preprocess(const HloInstruction* hlo) override; 129 Status Postprocess(const HloInstruction* hlo) override; 130 131 // Decorates shape_size_ by returning 0 immediately if the shape does not have 132 // a layout. 133 int64 GetShapeSize(const Shape& shape) const; 134 135 // Set the rates used to calculate the time taken by the computation. These 136 // need to be set before visiting starts. set_flops_per_second(float value)137 void set_flops_per_second(float value) { 138 per_second_rates_[kFlopsKey] = value; 139 } set_transcendentals_per_second(float value)140 void set_transcendentals_per_second(float value) { 141 per_second_rates_[kTranscendentalsKey] = value; 142 } set_bytes_per_second(float value)143 void set_bytes_per_second(float value) { 144 per_second_rates_[kBytesAccessedKey] = value; 145 } 146 147 // Returns properties for the computation. 148 float flop_count() const; 149 float transcendental_count() const; 150 float bytes_accessed() const; 151 float optimal_seconds() const; 152 153 // Returns the respective cost computed for a particular HLO instruction, or 0 154 // if the HLO was not found to have a cost in the analysis. 155 // 156 // Note that the cost for sub HLO instructions are also returned if asked. For 157 // example, body and condition of a while, fused instructions within a 158 // fusion, or the add instruction of a reduce. 159 int64 flop_count(const HloInstruction& hlo) const; 160 int64 transcendental_count(const HloInstruction& hlo) const; 161 int64 bytes_accessed(const HloInstruction& hlo) const; 162 int64 operand_bytes_accessed(const HloInstruction& hlo, int64 operand_num, 163 ShapeIndex index = {}) const; 164 int64 output_bytes_accessed(const HloInstruction& hlo, 165 ShapeIndex index = {}) const; 166 float optimal_seconds(const HloInstruction& hlo) const; 167 168 // Get bytes read/written by this HLO. If memory_space is provided, it returns 169 // the bytes read/written from/to the given memory space only. 170 int64 GetBytesRead(const HloInstruction& hlo, 171 absl::optional<int64> memory_space = absl::nullopt) const; 172 int64 GetBytesWritten( 173 const HloInstruction& hlo, 174 absl::optional<int64> memory_space = absl::nullopt) const; 175 properties()176 const Properties& properties() const { return properties_sum_; } property(const string & key)177 const float property(const string& key) const { 178 return GetProperty(key, properties()); 179 } 180 181 // Returns the specified per-second rate used by cost analysis. per_second_rate(const string & key)182 const float per_second_rate(const string& key) const { 183 return GetProperty(key, per_second_rates_); 184 } 185 186 // Return the key that is used to index into Properties for the specified 187 // input/output at the shape index. 188 static std::string GetOperandBytesAccessedKey(int64 operand_num, 189 ShapeIndex index = {}); 190 static std::string GetOutputBytesAccessedKey(ShapeIndex index = {}); 191 192 protected: 193 typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties; 194 195 // An FMA counts as two floating point operations in these analyzes. 196 static constexpr int64 kFmaFlops = 2; 197 198 HloCostAnalysis(const ShapeSizeFunction& shape_size, 199 const Properties& per_second_rates); 200 201 virtual std::unique_ptr<HloCostAnalysis> CreateNestedCostAnalysis( 202 const ShapeSizeFunction& shape_size, const Properties& per_second_rates); 203 204 // Returns the properties computed from visiting the computation rooted at the 205 // given hlo. The cost of visited sub HLO instructions is saved to 206 // hlo_properties_, which will be used by functions such as 207 // flop_count(hlo_instruction) to return cost of a particular HLO instruction. 208 StatusOr<Properties> ProcessSubcomputation(HloComputation* computation); 209 210 // Utility function to handle all element-wise operations. 211 Status HandleElementwiseOp(const HloInstruction* hlo_instruction); 212 213 // Returns the default value if the key is not present in the 214 // properties. Otherwise, returns the value that the key maps to from the 215 // properties parameter. 216 static float GetProperty(const string& key, const Properties& properties, 217 float default_value = 0.0f); 218 219 // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key 220 // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that 221 // the key maps to in the properties of the given hlo. 222 static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, 223 const HloToProperties& hlo_to_properties); 224 225 // Traverses a fusion operand to find the actual bytes accessed by the fusion 226 // node. 227 int64 FusionParameterReadBytes(const HloInstruction* hlo) const; 228 229 // Set bytes accessed by the specified operand and shape index. 230 void SetOperandBytesAccessed(int64 operand_num, float value); 231 void SetOperandBytesAccessed(int64 operand_num, ShapeIndex index, 232 float value); 233 234 // Set bytes accessed by the output at the shape index. 235 void SetOutputBytesAccessed(float value); 236 void SetOutputBytesAccessed(ShapeIndex index, float value); 237 238 // Function which computes the size of the top-level of a given shape (not 239 // including nested elements, if any). If null then bytes_accessed methods 240 // return an error. 241 const ShapeSizeFunction shape_size_; 242 243 HloToProperties hlo_properties_; 244 245 // If true, the time taken will be computed from the rates for each property 246 // and the total time will be the maximum time, which is the time of the 247 // bottleneck. 248 bool current_should_compute_bottleneck_time_; 249 250 // The properties of the currently visited instruction. A HandleFoo method can 251 // modify these to change the default values computed in Preprocess. 252 Properties current_properties_; 253 254 // The sum of the properties of all HLOs in the computation. 255 Properties properties_sum_; 256 257 // How much of each property can be processed per second. E.g. if the property 258 // is bytes accessed, this is the number of bytes that can be processed per 259 // second. Is empty if no rates have been set. 260 Properties per_second_rates_; 261 262 TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis); 263 }; 264 265 } // namespace xla 266 267 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 268