/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_

#define _USE_MATH_DEFINES

#include <complex>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {

// Responsible for evaluating HLO and obtaining literals as the evaluation
// results.
//
// This class is not thread-safe.
class HloEvaluator : public DfsHloVisitorWithDefault {
 public:
  // Only evaluate up to max_loop_iterations per while-loop execution if
  // specified.
  explicit HloEvaluator(int64_t max_loop_iterations = -1);

  // Evaluates an HLO module with the given argument literals. Returns the
  // evaluated result as a literal if successful.
  //
  // Precondition: The indices of arg_literals correspond to the parameter
  // numbers of the HLO parameters in the computation. See the comment on the
  // computation overload below for an example.
  //
  // (The dummy template arg reduces the overloading priority of the second
  // overload so that Evaluate(module, {}) resolves unambiguously.)
  StatusOr<Literal> Evaluate(const HloModule& module,
                             absl::Span<const Literal* const> arg_literals) {
    return Evaluate(*module.entry_computation(), arg_literals);
  }
  template <typename Dummy = void>
  StatusOr<Literal> Evaluate(const HloModule& module,
                             absl::Span<const Literal> arg_literals) {
    return Evaluate(*module.entry_computation(), arg_literals);
  }
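
  // Example usage (a minimal sketch; `module`, the literal values, and the
  // use of TF_ASSIGN_OR_RETURN are illustrative assumptions, not part of
  // this header):
  //
  //   HloEvaluator evaluator;
  //   Literal arg0 = LiteralUtil::CreateR0<float>(1.0f);
  //   Literal arg1 = LiteralUtil::CreateR0<float>(2.0f);
  //   TF_ASSIGN_OR_RETURN(Literal result,
  //                       evaluator.Evaluate(*module, {&arg0, &arg1}));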
  // Evaluates an HLO computation with the given argument literals. Returns
  // the evaluated result as a literal if successful.
  // Precondition: The indices of arg_literals correspond to the parameter
  // numbers of the HLO parameters in the computation. For example, consider
  // the following graph:
  //
  //                     *
  //                   /   \
  //                  +     Parameter1
  //                /   \
  //               /     \
  //      Parameter0     Constant
  //
  // where Parameter0 has parameter_number 0 and Parameter1 has
  // parameter_number 1 in this computation. The first literal in the input
  // array then maps to Parameter0 and the second maps to Parameter1.
  //
  // (The dummy template arg reduces the overloading priority of the second
  // overload so that Evaluate(computation, {}) resolves unambiguously.)
  StatusOr<Literal> Evaluate(const HloComputation& computation,
                             absl::Span<const Literal* const> arg_literals);
  template <typename Dummy = void>
  StatusOr<Literal> Evaluate(const HloComputation& computation,
                             absl::Span<const Literal> arg_literals) {
    std::vector<const Literal*> arg_literal_ptrs;
    for (const auto& l : arg_literals) {
      arg_literal_ptrs.push_back(&l);
    }
    return Evaluate(computation, arg_literal_ptrs);
  }

  // Gets the value of running a single HLO instruction.
  //
  // All of the operands to this instruction must be constants.
  StatusOr<Literal> Evaluate(HloInstruction* instruction);

  // Same as Evaluate, except that it returns false on error and accepts an
  // output pointer.
  bool TryEvaluate(HloInstruction* instruction, Literal* result);

  // Evaluates a single HLO instruction, substituting the given literals for
  // some of the instruction's operands.
  //
  // For example, given instruction = op(A, B, C) and the map
  // {A = x, C = y}, this evaluates op(x, B, y).
  StatusOr<Literal> EvaluateWithSubstitutions(
      const HloInstruction* instruction,
      const std::unordered_map<const HloInstruction*, const Literal*>&
          substitutions);

  StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
                                                const Literal& lhs,
                                                const Literal& rhs);

  StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
                                               const Literal& operand);

  StatusOr<Literal> EvaluateElementwiseTernaryOp(HloOpcode opcode,
                                                 const Literal& lhs,
                                                 const Literal& rhs,
                                                 const Literal& ehs);

  StatusOr<Literal> EvaluateElementwiseCompareOp(ComparisonDirection direction,
                                                 const Literal& lhs,
                                                 const Literal& rhs);

  StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
                                  const PrecisionConfig& precision_config,
                                  const Literal& lhs, const Literal& rhs);

  void set_dynamic_dimension_inference(
      DynamicDimensionInference* dynamic_dimension_inference) {
    dynamic_dimension_inference_ = dynamic_dimension_inference;
  }

  DynamicDimensionInference* dynamic_dimension_inference() {
    return dynamic_dimension_inference_;
  }

  // Enables the fast path for certain operations like dot or convolution.
  void set_use_fast_path(bool value) { use_fast_path_ = value; }

  // Handles evaluation of a custom-call op.
  // Operand literals are provided in |operands| and implementations must
  // return the resulting output literal.
  using CustomCallHandler = std::function<StatusOr<Literal>(
      HloInstruction* custom_call, absl::Span<const Literal*> operands)>;

  // Sets a handler that is called during evaluation for custom-call ops.
  // If no handler is defined, the default error behavior will occur. The
  // handler will be provided evaluated literals for all operands and is
  // expected to return an output literal of the appropriate shape.
  void set_custom_call_handler(CustomCallHandler handler) {
    custom_call_handler_ = std::move(handler);
  }
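
  // Example usage (a minimal sketch; the pass-through behavior is an
  // illustrative assumption, not a prescribed implementation):
  //
  //   evaluator.set_custom_call_handler(
  //       [](HloInstruction* custom_call,
  //          absl::Span<const Literal*> operands) -> StatusOr<Literal> {
  //         // Hypothetical handler that simply forwards its first operand.
  //         return operands[0]->Clone();
  //       });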
  // Returns the result of a matrix multiply `lhs x rhs`.
  static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
      const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
  static std::unique_ptr<Array2D<float>> MatmulArray2D(
      const Array2D<float>& lhs, const Array2D<float>& rhs);
  static std::unique_ptr<Array2D<double>> MatmulArray2D(
      const Array2D<double>& lhs, const Array2D<double>& rhs);
  static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D(
      const Array2D<std::complex<float>>& lhs,
      const Array2D<std::complex<float>>& rhs);
  static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D(
      const Array2D<std::complex<double>>& lhs,
      const Array2D<std::complex<double>>& rhs);
  static std::unique_ptr<Array2D<int32>> MatmulArray2D(
      const Array2D<int32>& lhs, const Array2D<int32>& rhs);
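
  // Example usage (a minimal sketch; the values are illustrative):
  //
  //   Array2D<float> lhs({{1.f, 2.f}, {3.f, 4.f}});
  //   Array2D<float> rhs({{5.f, 6.f}, {7.f, 8.f}});
  //   std::unique_ptr<Array2D<float>> product =
  //       HloEvaluator::MatmulArray2D(lhs, rhs);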
 protected:
  // Make HloEvaluatorTypedVisitor a friend because it is logically part of
  // this class.
  //
  // A straightforward implementation would be to make it a nested class
  // declared and defined in hlo_evaluator.cc. Instead,
  // HloEvaluatorTypedVisitor lives as a separate class with its own header
  // because its template gets instantiated many times and we want to use
  // extern templates to shard out the compilation of those instantiations
  // across multiple cc files.
  template <typename ReturnT, typename ElementwiseT>
  friend class HloEvaluatorTypedVisitor;

  // Wraps around instruction handling to infer types before dispatching to
  // the corresponding typed visitor.
  Status DefaultAction(HloInstruction* hlo) override {
    return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
  }

  Status Preprocess(HloInstruction* hlo) override;

  Status Postprocess(HloInstruction* hlo) override;

  // Operations that are type-agnostic or always return a specific type, such
  // as HandleIsFinite, where a boolean is always returned.
  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override;

  Status HandleSetDimensionSize(HloInstruction* set_dimension_size) override;

  Status HandleParameter(HloInstruction* parameter) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleIsFinite(HloInstruction* is_finite) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleTuple(HloInstruction* tuple) override;

  Status HandleFft(HloInstruction* fft) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleCopyStart(HloInstruction* copy_start) override;

  Status HandleCopyDone(HloInstruction* copy_done) override;

  Status HandleConditional(HloInstruction* conditional) override;

  Status HandleCall(HloInstruction* call) override;

  Status HandleFusion(HloInstruction* fusion) override;

  Status HandleWhile(HloInstruction* while_hlo) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleTupleSelect(HloInstruction* tuple_select) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleAfterAll(HloInstruction* after_all) override;

  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReduce(HloInstruction* reduce) override;

  Status HandleReduceWindow(HloInstruction* hlo) override;

  Status HandleCustomCall(HloInstruction* custom_call) override;

  // Unsupported HLOs. Note that some of them (such as BatchNorm*) are
  // typically expanded in a semantics-preserving way into other HLOs by an
  // expansion pass in the HLO optimization pipeline during compilation; the
  // expanded HLOs can then be handled by the evaluator.
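  // For example (a minimal sketch; BatchNormExpander and its constructor
  // flags come from batchnorm_expander.h, and this pipeline setup is an
  // illustrative assumption, not a prescribed recipe):
  //
  //   HloPassPipeline pipeline("expand-batch-norm");
  //   pipeline.AddPass<BatchNormExpander>(
  //       /*rewrite_training_op=*/true,
  //       /*rewrite_inference_op=*/true,
  //       /*rewrite_grad_op=*/true);
  //   TF_RETURN_IF_ERROR(pipeline.Run(module).status());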
  Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
    return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
  }
  Status HandleBatchNormInference(
      HloInstruction* batch_norm_inference) override {
    return Unimplemented(
        "BatchNormInference HLO is unsupported by the evaluator.");
  }
  Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
    return Unimplemented(
        "BatchNormTraining HLO is unsupported by the evaluator.");
  }
  Status HandleInfeed(HloInstruction* infeed) override {
    return Unimplemented("Infeed HLO is unsupported by the evaluator.");
  }
  Status HandleOutfeed(HloInstruction* outfeed) override {
    return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
  }

  // Returns the already-evaluated literal result for the instruction.
  //
  // A Constant instruction is considered evaluated and its literal will be
  // returned directly without looking up the cache.
  //
  // Similarly, a Parameter instruction is considered evaluated and its
  // literal is looked up in arg_literals.
  //
  // Crashes with a log message if the given instruction has not been
  // evaluated previously.
  const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
    if (hlo->IsConstant()) {
      return hlo->literal();
    }
    if (hlo->opcode() == HloOpcode::kParameter) {
      return *arg_literals_.at(hlo->parameter_number());
    }
    auto it = evaluated_.find(hlo);
    CHECK(it != evaluated_.end())
        << "could not find evaluated value for: " << hlo->ToString();
    return it->second;
  }
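
  // A typical handler uses this accessor roughly as follows (a minimal
  // sketch; the real handlers are defined in hlo_evaluator.cc, and this
  // HandleCopy body is an illustrative assumption):
  //
  //   Status HandleCopy(HloInstruction* copy) override {
  //     evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
  //     return Status::OK();
  //   }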
  // Tracks the HLO instruction and its evaluated literal result.
  //
  // Parameters and constants aren't stored here; see the implementation of
  // GetEvaluatedLiteralFor.
  //
  // TODO(b/35950897): have better memory management here to free instructions
  // that are no longer a parent for any other subsequent instruction in
  // post-ordering.
  //
  // Must be cleared for each evaluation.
  //
  // Storing Literal in place requires the container to have pointer
  // stability, so we cannot use flat_hash_map any more.
  absl::node_hash_map<const HloInstruction*, Literal> evaluated_;

  // Whether to use the Eigen-based fast path in the evaluator.
  bool use_fast_path_ = false;

 private:
  template <typename ReturnT, typename NativeT>
  static StatusOr<Literal> ElementWiseUnaryOpImpl(
      HloInstruction* instruction,
      const std::function<ReturnT(NativeT)>& unary_op,
      const Literal& operand_literal) {
    const auto shape = instruction->shape();
    const auto* operand = instruction->operand(0);
    TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));

    Literal result(shape);
    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
          return unary_op(operand_literal.Get<NativeT>(multi_index));
        }));
    return std::move(result);
  }

  // Map from a primitive type to its associated (templated) DfsHloVisitor.
  std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];

  // Caches pointers to input literals, assuming they are in post-order.
  // Literals are not owned by this class; they must outlive each invocation
  // of an Evaluate* method.
  // Must be cleared for each evaluation.
  std::vector<const Literal*> arg_literals_;

  // Maximum number of loop iterations to execute; negative means no limit.
  int64 max_loop_iterations_ = 0;

  // Module-level seed handle.
  uint64 seed_ = 0;
  // RNG engine.
  std::minstd_rand0 engine_;

  // DynamicDimensionInference is used to evaluate GetDimensionSize, which
  // returns the dynamic dimension size of its operand.
  DynamicDimensionInference* dynamic_dimension_inference_ = nullptr;

  // Optional handler for custom_call ops.
  CustomCallHandler custom_call_handler_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
};

std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs,
                                              const Array2D<float>& rhs);
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_