1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 18 19 #include "absl/types/span.h" 20 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 21 #include "tensorflow/compiler/xla/service/hlo_computation.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 24 #include "tensorflow/compiler/xla/shape_util.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace xla { 31 32 // HloCostAnalysis traverses an HLO graph and calculates the amount of 33 // computations required for the graph. Each HLO instruction handler provides 34 // the computation cost of the instruction, and the values are accumulated 35 // during the traversal for the entire graph. We treat normal floating point 36 // operations separately from transcendental operations. 37 class HloCostAnalysis : public ConstDfsHloVisitor { 38 public: 39 // Each HLO is associated to a vector of properties with the indices given 40 // below. Sub-classes can add further properties. 41 // MSVC 14.0 limitation requires the consts. 42 typedef std::map<string, float> Properties; 43 static constexpr const char kFlopsKey[] = "flops"; 44 static constexpr const char kTranscendentalsKey[] = "transcendentals"; 45 static constexpr const char kBytesAccessedKey[] = "bytes accessed"; 46 static constexpr const char kOptimalSecondsKey[] = "optimal_seconds"; 47 48 // shape_size is a function which returns the size in bytes of the top-level 49 // buffer of a shape. 50 using ShapeSizeFunction = std::function<int64(const Shape&)>; 51 explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); 52 53 Status HandleElementwiseUnary(const HloInstruction* hlo) override; 54 Status HandleElementwiseBinary(const HloInstruction* hlo) override; 55 Status HandleConstant(const HloInstruction* constant) override; 56 Status HandleIota(const HloInstruction* iota) override; 57 Status HandleGetTupleElement( 58 const HloInstruction* get_tuple_element) override; 59 Status HandleSelect(const HloInstruction* hlo) override; 60 Status HandleTupleSelect(const HloInstruction* hlo) override; 61 Status HandleCompare(const HloInstruction* compare) override; 62 Status HandleClamp(const HloInstruction* clamp) override; 63 Status HandleReducePrecision(const HloInstruction* hlo) override; 64 Status HandleConcatenate(const HloInstruction* concatenate) override; 65 Status HandleCopyStart(const HloInstruction* send) override; 66 Status HandleCopyDone(const HloInstruction* send_done) override; 67 Status HandleSend(const HloInstruction* send) override; 68 Status HandleSendDone(const HloInstruction* send_done) override; 69 Status HandleRecv(const HloInstruction* recv) override; 70 Status HandleRecvDone(const HloInstruction* recv_done) override; 71 Status HandleConvert(const HloInstruction* convert) override; 72 Status HandleCopy(const HloInstruction* copy) override; 73 Status HandleDomain(const HloInstruction* domain) override; 74 Status HandleDot(const HloInstruction* dot) override; 75 Status HandleConvolution(const HloInstruction* convolution) override; 76 Status HandleFft(const HloInstruction* fft) override; 77 Status HandleTriangularSolve(const HloInstruction* hlo) override; 78 Status HandleCholesky(const HloInstruction* hlo) override; 79 Status HandleAllGather(const HloInstruction* hlo) override; 80 Status HandleAllGatherStart(const HloInstruction* hlo) override; 81 Status HandleAllGatherDone(const HloInstruction* hlo) override; 82 Status HandleAllReduce(const HloInstruction* crs) override; 83 Status HandleReduceScatter(const HloInstruction* hlo) override; 84 Status HandleAllReduceStart(const HloInstruction* hlo) override; 85 Status HandleAllReduceDone(const HloInstruction* hlo) override; 86 Status HandleAllToAll(const HloInstruction* hlo) override; 87 Status HandleCollectivePermute(const HloInstruction* hlo) override; 88 Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; 89 Status HandleCollectivePermuteDone(const HloInstruction* hlo) override; 90 Status HandleReplicaId(const HloInstruction* hlo) override; 91 Status HandlePartitionId(const HloInstruction* hlo) override; 92 Status HandleInfeed(const HloInstruction* infeed) override; 93 Status HandleOutfeed(const HloInstruction* outfeed) override; 94 Status HandleRng(const HloInstruction* random) override; 95 Status HandleRngBitGenerator(const HloInstruction* random) override; 96 Status HandleRngGetAndUpdateState(const HloInstruction* random) override; 97 Status HandleReverse(const HloInstruction* reverse) override; 98 Status HandleSort(const HloInstruction* sort) override; 99 Status HandleParameter(const HloInstruction* parameter) override; 100 Status HandleReduce(const HloInstruction* reduce) override; 101 Status HandleBatchNormTraining( 102 const HloInstruction* batch_norm_training) override; 103 Status HandleBatchNormInference( 104 const HloInstruction* batch_norm_inference) override; 105 Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; 106 Status HandleFusion(const HloInstruction* fusion) override; 107 Status HandleCall(const HloInstruction* call) override; 108 Status HandleCustomCall(const HloInstruction* custom_call) override; 109 Status HandleSlice(const HloInstruction* slice) override; 110 Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; 111 Status HandleDynamicUpdateSlice( 112 const HloInstruction* dynamic_update_slice) override; 113 Status HandleTuple(const HloInstruction* tuple) override; 114 Status HandleMap(const HloInstruction* map) override; 115 Status HandleReduceWindow(const HloInstruction* reduce_window) override; 116 Status HandleSelectAndScatter(const HloInstruction* instruction) override; 117 Status HandleBitcast(const HloInstruction* bitcast) override; 118 Status HandleBroadcast(const HloInstruction* broadcast) override; 119 Status HandlePad(const HloInstruction* pad) override; 120 Status HandleReshape(const HloInstruction* reshape) override; 121 Status HandleDynamicReshape(const HloInstruction* reshape) override; 122 Status HandleAddDependency(const HloInstruction* add_dependency) override; 123 Status HandleAfterAll(const HloInstruction* token) override; 124 Status HandleTranspose(const HloInstruction* transpose) override; 125 Status HandleWhile(const HloInstruction* xla_while) override; 126 Status HandleConditional(const HloInstruction* conditional) override; 127 Status HandleGather(const HloInstruction* gather) override; 128 Status HandleScatter(const HloInstruction* scatter) override; 129 Status HandleGetDimensionSize(const HloInstruction* get_size) override; 130 Status HandleSetDimensionSize(const HloInstruction* set_size) override; 131 Status FinishVisit(const HloInstruction* root) override; 132 133 Status Preprocess(const HloInstruction* hlo) override; 134 Status Postprocess(const HloInstruction* hlo) override; 135 136 // Decorates shape_size_ by returning 0 immediately if the shape does not have 137 // a layout. 138 int64 GetShapeSize(const Shape& shape) const; 139 140 // Set the rates used to calculate the time taken by the computation. These 141 // need to be set before visiting starts. set_flops_per_second(float value)142 void set_flops_per_second(float value) { 143 per_second_rates_[kFlopsKey] = value; 144 } set_transcendentals_per_second(float value)145 void set_transcendentals_per_second(float value) { 146 per_second_rates_[kTranscendentalsKey] = value; 147 } set_bytes_per_second(float value)148 void set_bytes_per_second(float value) { 149 per_second_rates_[kBytesAccessedKey] = value; 150 } 151 152 // Returns properties for the computation. 153 float flop_count() const; 154 float transcendental_count() const; 155 float bytes_accessed() const; 156 float optimal_seconds() const; 157 158 // Returns the respective cost computed for a particular HLO instruction, or 0 159 // if the HLO was not found to have a cost in the analysis. 160 // 161 // Note that the cost for sub HLO instructions are also returned if asked. For 162 // example, body and condition of a while, fused instructions within a 163 // fusion, or the add instruction of a reduce. 164 int64 flop_count(const HloInstruction& hlo) const; 165 int64 transcendental_count(const HloInstruction& hlo) const; 166 int64 bytes_accessed(const HloInstruction& hlo) const; 167 int64 operand_bytes_accessed(const HloInstruction& hlo, int64_t operand_num, 168 ShapeIndex index = {}) const; 169 int64 output_bytes_accessed(const HloInstruction& hlo, 170 ShapeIndex index = {}) const; 171 float optimal_seconds(const HloInstruction& hlo) const; 172 173 // Get bytes read/written by this HLO. If memory_space is provided, it returns 174 // the bytes read/written from/to the given memory space only. 175 int64 GetBytesRead(const HloInstruction& hlo, 176 absl::optional<int64> memory_space = absl::nullopt) const; 177 int64 GetBytesWritten( 178 const HloInstruction& hlo, 179 absl::optional<int64> memory_space = absl::nullopt) const; 180 properties()181 const Properties& properties() const { return properties_sum_; } property(const string & key)182 const float property(const string& key) const { 183 return GetProperty(key, properties()); 184 } 185 186 // Returns the specified per-second rate used by cost analysis. per_second_rate(const string & key)187 const float per_second_rate(const string& key) const { 188 return GetProperty(key, per_second_rates_); 189 } 190 191 // Return the key that is used to index into Properties for the specified 192 // input/output at the shape index. 193 static std::string GetOperandBytesAccessedKey(int64_t operand_num, 194 ShapeIndex index = {}); 195 static std::string GetOutputBytesAccessedKey(ShapeIndex index = {}); 196 197 protected: 198 typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties; 199 200 // An FMA counts as two floating point operations in these analyzes. 201 static constexpr int64_t kFmaFlops = 2; 202 203 HloCostAnalysis(const ShapeSizeFunction& shape_size, 204 const Properties& per_second_rates); 205 206 virtual std::unique_ptr<HloCostAnalysis> CreateNestedCostAnalysis( 207 const ShapeSizeFunction& shape_size, const Properties& per_second_rates); 208 209 // Returns the properties computed from visiting the computation rooted at the 210 // given hlo. The cost of visited sub HLO instructions is saved to 211 // hlo_properties_, which will be used by functions such as 212 // flop_count(hlo_instruction) to return cost of a particular HLO instruction. 213 StatusOr<Properties> ProcessSubcomputation(HloComputation* computation); 214 215 // Utility function to handle all element-wise operations. 216 Status HandleElementwiseOp(const HloInstruction* hlo_instruction); 217 218 // Returns the default value if the key is not present in the 219 // properties. Otherwise, returns the value that the key maps to from the 220 // properties parameter. 221 static float GetProperty(const string& key, const Properties& properties, 222 float default_value = 0.0f); 223 224 // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key 225 // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that 226 // the key maps to in the properties of the given hlo. 227 static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, 228 const HloToProperties& hlo_to_properties); 229 230 // Traverses a fusion operand to find the actual bytes accessed by the fusion 231 // node. 232 int64 FusionParameterReadBytes(const HloInstruction* hlo) const; 233 234 // Set bytes accessed by the specified operand and shape index. 235 void SetOperandBytesAccessed(int64_t operand_num, float value); 236 void SetOperandBytesAccessed(int64_t operand_num, ShapeIndex index, 237 float value); 238 239 // Set bytes accessed by the output at the shape index. 240 void SetOutputBytesAccessed(float value); 241 void SetOutputBytesAccessed(ShapeIndex index, float value); 242 243 // Function which computes the size of the top-level of a given shape (not 244 // including nested elements, if any). If null then bytes_accessed methods 245 // return an error. 246 const ShapeSizeFunction shape_size_; 247 248 HloToProperties hlo_properties_; 249 250 // If true, the time taken will be computed from the rates for each property 251 // and the total time will be the maximum time, which is the time of the 252 // bottleneck. 253 bool current_should_compute_bottleneck_time_; 254 255 // The properties of the currently visited instruction. A HandleFoo method can 256 // modify these to change the default values computed in Preprocess. 257 Properties current_properties_; 258 259 // The sum of the properties of all HLOs in the computation. 260 Properties properties_sum_; 261 262 // How much of each property can be processed per second. E.g. if the property 263 // is bytes accessed, this is the number of bytes that can be processed per 264 // second. Is empty if no rates have been set. 265 Properties per_second_rates_; 266 267 TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis); 268 }; 269 270 } // namespace xla 271 272 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 273