/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_

#define _USE_MATH_DEFINES

#include <complex>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <random>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// Represents a parsed static while loop.
We normalize the loop representation 45 // so that it starts from the induction_var_init_value and increments by 46 // step_size until it exceeds or goes below loop_bound. 47 struct ParsedStaticWhileLoop { 48 // The number of iterations to be executed. 49 int64_t trip_count = -1; 50 // The tuple index of the induction variable in the while argument tuple. 51 int64_t induction_var_index = -1; 52 // The induction variable's initial value. 53 int64_t induction_var_init_value = -1; 54 // The induction variable is incremented by this number (could be negative) 55 // in each iteration. 56 int64_t step_size = -1; 57 int64_t loop_bound = -1; 58 }; 59 60 // Indicates whether a parsed while loop is static or dynamic. If the loop is 61 // static, it contains a value for StaticLoopInfo; otherwise the loop is 62 // dynamic. We consider a loop dynamic if its induction variable's initial 63 // value or the loop bound's value depends on the while's parent computation's 64 // parameter. 65 struct ParsedWhileLoop { 66 std::optional<ParsedStaticWhileLoop> static_while_loop; is_dynamicParsedWhileLoop67 bool is_dynamic() const { return !static_while_loop.has_value(); } 68 }; 69 constexpr ParsedWhileLoop kParsedDynamicWhileLoop = ParsedWhileLoop(); 70 71 // Tries to parse a while loop using a set of predefined patterns. 72 // Returns the parsing result. 73 std::optional<ParsedWhileLoop> PatternMatchParseWhileLoop( 74 HloInstruction* while_op); 75 76 // Responsible for evaluating HLO and obtain literal as the evaluation results. 77 // 78 // This class is not thread-safe. 79 class HloEvaluator : public DfsHloVisitorWithDefault { 80 public: 81 // Only evaluate up to max_loop_iterations per while-loop execution if 82 // specified. 83 explicit HloEvaluator(int64_t max_loop_iterations = -1); 84 85 // Called by the evaluator to create an embedded evaluator to execute a 86 // sub-region of control flow. Subclasses should override this to return an 87 // instance of the subclass instead. 
CreateEmbedded(int64_t max_loop_iterations)88 virtual std::unique_ptr<HloEvaluator> CreateEmbedded( 89 int64_t max_loop_iterations) { 90 return std::make_unique<HloEvaluator>(max_loop_iterations); 91 } 92 93 // Evaluates an HLO module and an array of pointers to literals. Returns the 94 // evaluated result as a literal if successful. 95 // 96 // Precondition: The indices of arg_literals correspond to the parameter 97 // numbers of the HLO parameters in the computation. See comment below for an 98 // example. 99 // 100 // (Dummy template arg is to reduce the overloading priority of one overload 101 // so that Evaluate(module, {}) resolves unambiguously.) Evaluate(const HloModule & module,absl::Span<const Literal * const> arg_literals)102 StatusOr<Literal> Evaluate(const HloModule& module, 103 absl::Span<const Literal* const> arg_literals) { 104 return Evaluate(*module.entry_computation(), arg_literals); 105 } 106 template <typename Dummy = void> Evaluate(const HloModule & module,absl::Span<const Literal> arg_literals)107 StatusOr<Literal> Evaluate(const HloModule& module, 108 absl::Span<const Literal> arg_literals) { 109 return Evaluate(*module.entry_computation(), arg_literals); 110 } 111 112 // Evaluates an HLO computation and an array of pointers to literals. 113 // Returns the evaluated result as a literal if successful. 114 // Precondition: The indices of arg_literals correspond to the parameter 115 // numbers of the HLO parameters in the computation. For e.g., consider the 116 // following graph: 117 // 118 // * 119 // / \ 120 // + Parameter1 121 // / \ 122 // / \ 123 // Parameter0 Constant 124 // 125 // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number 126 // 1 in this computation. The input literals array will then have its first 127 // literal map to Parameter0 and the second map to Parameter1. 
128 // 129 // (Dummy template arg is to reduce the overloading priority of one overload 130 // so that Evaluate(module, {}) resolves unambiguously.) 131 StatusOr<Literal> Evaluate(const HloComputation& computation, 132 absl::Span<const Literal* const> arg_literals); 133 template <typename Dummy = void> Evaluate(const HloComputation & computation,absl::Span<const Literal> arg_literals)134 StatusOr<Literal> Evaluate(const HloComputation& computation, 135 absl::Span<const Literal> arg_literals) { 136 std::vector<const Literal*> arg_literal_ptrs; 137 for (const auto& l : arg_literals) { 138 arg_literal_ptrs.push_back(&l); 139 } 140 return Evaluate(computation, arg_literal_ptrs); 141 } 142 143 // Gets the value of running a single HLO instruction. 144 // 145 // This function may recursively evaluate the dependency of this instruction 146 // within its parent computation until it encounters something that cannot be 147 // evaluated, such as an Infeed or a Parameter instruction. 148 // It makes best effort to partially evaluate a dependency if possible. 149 StatusOr<Literal> Evaluate( 150 HloInstruction* instruction, 151 bool recursively_evaluate_nonconstant_operands = false); 152 153 // Same as Evaluate, except returning false on error and accepts an output 154 // pointer. 155 bool TryEvaluate(HloInstruction* instruction, Literal* result, 156 bool recursively_evaluate_nonconstant_operands = false); 157 158 // Evaluates a single HLO instruction, substituting the given literals for 159 // some of the instruction's operands. 160 // 161 // For example, given instruction = op(A, B, C) and the map 162 // {A = x, C = y}, this evaluates op(x, B, y). 
163 StatusOr<Literal> EvaluateWithSubstitutions( 164 const HloInstruction* instruction, 165 const absl::flat_hash_map<const HloInstruction*, const Literal*>& 166 substitutions); 167 168 StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode, 169 const Literal& lhs, 170 const Literal& rhs); 171 172 StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode, 173 const Literal& operand); 174 175 StatusOr<Literal> EvaluateElementwiseTernaryOp(HloOpcode opcode, 176 const Literal& lhs, 177 const Literal& rhs, 178 const Literal& ehs); 179 180 StatusOr<Literal> EvaluateElementwiseCompareOp(ComparisonDirection direction, 181 const Literal& lhs, 182 const Literal& rhs); 183 184 StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers, 185 const PrecisionConfig& precision_config, 186 const Literal& lhs, const Literal& rhs); 187 set_dynamic_dimension_inference(DynamicDimensionInference * dynamic_dimension_inference)188 void set_dynamic_dimension_inference( 189 DynamicDimensionInference* dynamic_dimension_inference) { 190 dynamic_dimension_inference_ = dynamic_dimension_inference; 191 } 192 dynamic_dimension_inference()193 DynamicDimensionInference* dynamic_dimension_inference() { 194 return dynamic_dimension_inference_; 195 } 196 197 // Enable the fast path for certain operations like dot or convolution. set_use_fast_path(bool value)198 void set_use_fast_path(bool value) { use_fast_path_ = value; } 199 200 // Handles evaluation of a custom-call op. 201 // Operand literals are provided in |operands| and implementations must 202 // populate |output| before returning. 203 using CustomCallHandler = std::function<StatusOr<Literal>( 204 HloInstruction* custom_call, absl::Span<const Literal*> operands)>; 205 206 // Sets a handler that is called during evaluation for custom-call ops. 207 // If no handler is defined the default error behavior will occur. 
The handler 208 // will be provided evaluated literals for all operands and is expected to 209 // return an output literal of the appropriate shape. set_custom_call_handler(std::function<StatusOr<Literal> (HloInstruction * custom_call,absl::Span<const Literal * > operands)> handler)210 void set_custom_call_handler( 211 std::function<StatusOr<Literal>(HloInstruction* custom_call, 212 absl::Span<const Literal*> operands)> 213 handler) { 214 custom_call_handler_ = std::move(handler); 215 } 216 217 // Returns the result of a matrix multiply `lhs x rhs`. 218 static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D( 219 const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs); 220 static std::unique_ptr<Array2D<float>> MatmulArray2D( 221 const Array2D<float>& lhs, const Array2D<float>& rhs); 222 static std::unique_ptr<Array2D<double>> MatmulArray2D( 223 const Array2D<double>& lhs, const Array2D<double>& rhs); 224 static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D( 225 const Array2D<std::complex<float>>& lhs, 226 const Array2D<std::complex<float>>& rhs); 227 static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D( 228 const Array2D<std::complex<double>>& lhs, 229 const Array2D<std::complex<double>>& rhs); 230 static std::unique_ptr<Array2D<int32_t>> MatmulArray2D( 231 const Array2D<int32_t>& lhs, const Array2D<int32_t>& rhs); 232 233 protected: 234 // Evaluates the given instruction, and stores the evaluation result in the 235 // evaluated_ map. 236 // When a non-empty shape_index is given, the instruction may be partially 237 // evaluated at the given shape_index and the rest of the result could be 238 // marked as undetermined unless it has been previously evaluated using 239 // EvaluateInternal. Such partial evaluation reduces the computation and 240 // memory overhead in cases where we need only one tuple element by avoiding 241 // the evaluation of a full tuple. 
242 Status EvaluateInternal( 243 HloInstruction* instruction, const ShapeIndex& shape_index = {}, 244 bool recursively_evaluate_nonconstant_operands = false); 245 // Make HloEvaluatorTypedVisitor a friend because it is logically part of this 246 // class. 247 // 248 // A straightforward implementation would be to make it a nested class 249 // declared and defined in hlo_evaluator.cc. Instead HloEvaluatorTypedVisitor 250 // lives as a separate class with its own header because its template gets 251 // instantiated many times and we want to use extern templates to shard out 252 // the compilation of those instantiations across multiple cc files. 253 template <typename ReturnT, typename ElementwiseT> 254 friend class HloEvaluatorTypedVisitor; 255 256 // Wraps around instruction handling to infer types before dispatching to 257 // the corresponding typed Visitor. DefaultAction(HloInstruction * hlo)258 Status DefaultAction(HloInstruction* hlo) override { 259 return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get()); 260 } 261 262 Status Preprocess(HloInstruction* hlo) override; 263 264 Status Postprocess(HloInstruction* hlo) override; 265 266 // Operations that are type-agnostic or always return a specific type, such as 267 // HandleIsFinite where boolean is always returned. 
268 // 269 Status HandleBitcast(HloInstruction* bitcast) override; 270 271 Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override; 272 273 Status HandleSetDimensionSize(HloInstruction* set_dimension_size) override; 274 275 Status HandleParameter(HloInstruction* parameter) override; 276 277 Status HandleInfeed(HloInstruction* infeed) override; 278 279 Status HandleConstant(HloInstruction* constant) override; 280 281 Status HandleConcatenate(HloInstruction* concatenate) override; 282 283 Status HandleReshape(HloInstruction* reshape) override; 284 285 Status HandleTranspose(HloInstruction* transpose) override; 286 287 Status HandleIsFinite(HloInstruction* is_finite) override; 288 289 Status HandleCompare(HloInstruction* compare) override; 290 291 Status HandleTuple(HloInstruction* tuple) override; 292 293 Status HandleFft(HloInstruction* fft) override; 294 295 Status HandleGather(HloInstruction* gather) override; 296 297 Status HandleScatter(HloInstruction* hlo) override; 298 299 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 300 301 Status HandleAsyncStart(HloInstruction* async_start) override; 302 303 Status HandleAsyncUpdate(HloInstruction* async_update) override; 304 305 Status HandleAsyncDone(HloInstruction* async_done) override; 306 307 Status HandleCopy(HloInstruction* copy) override; 308 309 Status HandleCopyStart(HloInstruction* copy_start) override; 310 311 Status HandleCopyDone(HloInstruction* copy_done) override; 312 313 Status HandleConditional(HloInstruction* conditional) override; 314 315 Status HandleCall(HloInstruction* call) override; 316 317 Status HandleFusion(HloInstruction* fusion) override; 318 319 Status HandleWhile(HloInstruction* while_hlo) override; 320 321 Status HandleSelect(HloInstruction* select) override; 322 323 Status HandleBroadcast(HloInstruction* broadcast) override; 324 325 Status HandleAfterAll(HloInstruction* after_all) override; 326 327 Status HandleAddDependency(HloInstruction* 
add_dependency) override; 328 329 Status HandleSort(HloInstruction* sort) override; 330 331 Status HandleReal(HloInstruction* real) override; 332 333 Status HandleImag(HloInstruction* imag) override; 334 335 Status HandleComplex(HloInstruction* complex) override; 336 337 Status HandleReduce(HloInstruction* reduce) override; 338 339 Status HandleReduceWindow(HloInstruction* hlo) override; 340 341 Status HandleCustomCall(HloInstruction* custom_call) override; 342 343 // Unsupported HLOs, note some of them (such as BatchNorm*) are typically 344 // expanded in a semantic-preserving way into other HLOs by adding expansion 345 // HLO pass to the HLO optimization pass during compilation, which can then be 346 // handled by the evaluator. HandleBatchNormGrad(HloInstruction * batch_norm_grad)347 Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { 348 return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator."); 349 } HandleBatchNormInference(HloInstruction * batch_norm_inference)350 Status HandleBatchNormInference( 351 HloInstruction* batch_norm_inference) override { 352 return Unimplemented( 353 "BatchNormInference HLO is unsupported by the evaluator."); 354 } HandleBatchNormTraining(HloInstruction * batch_norm_training)355 Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { 356 return Unimplemented( 357 "BatchNormTraining HLO is unsupported by the evaluator."); 358 } HandleOutfeed(HloInstruction * outfeed)359 Status HandleOutfeed(HloInstruction* outfeed) override { 360 return Unimplemented("Outfeed HLO is unsupported by the evaluator."); 361 } 362 363 // Returns the already-evaluated literal result for the instruction. 364 // 365 // A Constant instruction is considered evaluated and its literal will be 366 // returned directly without looking up the cache. 367 // 368 // Similarly, a Parameter instruction is considered evaluated and its literal 369 // is looked up in arg_literals. 
370 // 371 // Crash with log if the given instruction has not been evaluated previously. GetEvaluatedLiteralFor(const HloInstruction * hlo)372 const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { 373 if (hlo->IsConstant()) { 374 return hlo->literal(); 375 } 376 if (hlo->opcode() == HloOpcode::kParameter && !arg_literals_.empty()) { 377 return *arg_literals_.at(hlo->parameter_number()); 378 } 379 380 auto it = evaluated_.find(hlo); 381 CHECK(it != evaluated_.end()) 382 << "could not find evaluated value for: " << hlo->ToString(); 383 return it->second; 384 } 385 386 // Returns true if the given hlo has been evaluated and cached. 387 bool IsAlreadyEvaluated(const HloInstruction* hlo, 388 const ShapeIndex& shape_index = {}) { 389 if (hlo->IsConstant()) { 390 return true; 391 } 392 if (hlo->opcode() == HloOpcode::kParameter && !arg_literals_.empty()) { 393 return true; 394 } 395 auto it = evaluated_.find(hlo); 396 if (it == evaluated_.end()) { 397 return false; 398 } 399 // We may evaluate some elements of a tuple-shaped instruction and mark 400 // the other elements as undetermined. This way we avoid the computation 401 // and memory overhead of evaluating a large tuple when only some elements 402 // are needed. By marking the other elements undetermined, we allow the 403 // evaluator to update the cached tuple literal when more elements are 404 // evaluated. 405 return it->second.IsDetermined(shape_index); 406 } 407 408 // Tracks the HLO instruction and its evaluated literal result. 409 // 410 // Parameters and constants aren't stored here, see implementation of 411 // GetEvaluatedLiteralFor. 412 // 413 // TODO(b/35950897): have better memory management here to free instructions 414 // that are no longer a parent for any other subsequent instruction in 415 // post-ordering. 416 // 417 // Must be cleared for each evaluation. 
418 // 419 // Storing Literal in place requires the container to have pointer stability 420 // so we cannot use flat_hash_map any more. 421 absl::node_hash_map<const HloInstruction*, Literal> evaluated_; 422 // Set by EvaluateInternal and opportunitiscally used by the HandleXXX 423 // functions. When non-empty, the HandleXXX function may evaluate the 424 // instruction at only the given shape index. 425 ShapeIndex visitor_shape_index_; 426 bool enable_partial_evaluation_ = false; 427 428 // Use fast path that uses eigen in the evaluator. 429 bool use_fast_path_ = false; 430 431 private: 432 template <typename ReturnT, typename NativeT> ElementWiseUnaryOpImpl(HloInstruction * instruction,const std::function<ReturnT (NativeT)> & unary_op,const Literal & operand_literal)433 static StatusOr<Literal> ElementWiseUnaryOpImpl( 434 HloInstruction* instruction, 435 const std::function<ReturnT(NativeT)>& unary_op, 436 const Literal& operand_literal) { 437 const auto shape = instruction->shape(); 438 const auto* operand = instruction->operand(0); 439 TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); 440 441 Literal result(shape); 442 TF_RETURN_IF_ERROR( 443 result.Populate<ReturnT>([&](absl::Span<const int64_t> multi_index) { 444 return unary_op(operand_literal.Get<NativeT>(multi_index)); 445 })); 446 return std::move(result); 447 } 448 449 // Map from a primitive type to its associated (templated) DfsHloVisitor. 450 std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE]; 451 452 // Caches pointers to input literals, assuming they are in post-order. 453 // Literals are not owned by this class, and they must outlive the lifetime of 454 // each invocation to the Evaluate* method. 455 // Must be cleared for each evaluation. 456 std::vector<const Literal*> arg_literals_; 457 458 // Max loop iterations to execute with no maximum if negative. 459 int64_t max_loop_iterations_ = 0; 460 461 // Module-level seed handle. 
462 uint64_t seed_ = 0; 463 // RNG engine. 464 std::minstd_rand0 engine_; 465 466 // DynamicDimensionInference is used to evaluate GetDimensionSize, which 467 // returns the dynamic dimension size of its operand. 468 DynamicDimensionInference* dynamic_dimension_inference_ = nullptr; 469 470 // Optional handler for custom_call ops. 471 std::function<StatusOr<Literal>(HloInstruction* custom_call, 472 absl::Span<const Literal*> operands)> 473 custom_call_handler_; 474 475 HloEvaluator(const HloEvaluator&) = delete; 476 HloEvaluator& operator=(const HloEvaluator&) = delete; 477 }; 478 479 std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs, 480 const Array2D<float>& rhs); 481 } // namespace xla 482 483 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ 484