/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_

#define _USE_MATH_DEFINES

#include <complex>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {

// Responsible for evaluating HLO and obtaining literals as the evaluation
// results.
//
// This class is not thread-safe.
class HloEvaluator : public DfsHloVisitorWithDefault {
 public:
  // Only evaluate up to max_loop_iterations per while-loop execution if
  // specified.
  explicit HloEvaluator(int64 max_loop_iterations = -1);

  // Evaluates an HLO module using the given literals as arguments to the
  // entry computation. Returns the evaluated result as a literal if
  // successful.
  //
  // Precondition: The indices of arg_literals correspond to the parameter
  // numbers of the HLO parameters in the computation. See the comment below
  // for an example.
  //
  // (Dummy template arg is to reduce the overloading priority of one overload
  // so that Evaluate(module, {}) resolves unambiguously.)
  StatusOr<Literal> Evaluate(const HloModule& module,
                             absl::Span<const Literal* const> arg_literals) {
    return Evaluate(*module.entry_computation(), arg_literals);
  }
  template <typename Dummy = void>
  StatusOr<Literal> Evaluate(const HloModule& module,
                             absl::Span<const Literal> arg_literals) {
    return Evaluate(*module.entry_computation(), arg_literals);
  }

  // Evaluates an HLO computation using the given literals as arguments.
  // Returns the evaluated result as a literal if successful.
  //
  // Precondition: The indices of arg_literals correspond to the parameter
  // numbers of the HLO parameters in the computation. For example, consider
  // the following graph:
  //
  //                  *
  //                /   \
  //               +   Parameter1
  //              / \
  //             /   \
  //    Parameter0   Constant
  //
  // where Parameter0 has parameter_number 0 and Parameter1 has
  // parameter_number 1 in this computation. The first literal in the input
  // array then maps to Parameter0 and the second maps to Parameter1.
  //
  // (Dummy template arg is to reduce the overloading priority of one overload
  // so that Evaluate(computation, {}) resolves unambiguously.)
  StatusOr<Literal> Evaluate(const HloComputation& computation,
                             absl::Span<const Literal* const> arg_literals);
  template <typename Dummy = void>
  StatusOr<Literal> Evaluate(const HloComputation& computation,
                             absl::Span<const Literal> arg_literals) {
    std::vector<const Literal*> arg_literal_ptrs;
    for (const auto& l : arg_literals) {
      arg_literal_ptrs.push_back(&l);
    }
    return Evaluate(computation, arg_literal_ptrs);
  }
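
  // A minimal usage sketch for the example graph above (illustrative only;
  // `module` and the literal values are hypothetical, and LiteralUtil is
  // assumed to be available via literal_util.h):
  //
  //   HloEvaluator evaluator;
  //   Literal arg0 = LiteralUtil::CreateR0<float>(2.0f);  // Parameter0
  //   Literal arg1 = LiteralUtil::CreateR0<float>(3.0f);  // Parameter1
  //   StatusOr<Literal> result =
  //       evaluator.Evaluate(*module->entry_computation(), {&arg0, &arg1});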

  // Gets the value of running a single HLO instruction.
  //
  // All of the operands to this instruction must be constants.
  StatusOr<Literal> Evaluate(HloInstruction* instruction);

  // Same as Evaluate, except that it returns false on error and accepts an
  // output pointer.
  bool TryEvaluate(HloInstruction* instruction, Literal* result);

  // Evaluates a single HLO instruction, substituting the given literals for
  // some of the instruction's operands.
  //
  // For example, given instruction = op(A, B, C) and the map
  // {A = x, C = y}, this evaluates op(x, B, y).
  StatusOr<Literal> EvaluateWithSubstitutions(
      const HloInstruction* instruction,
      const std::unordered_map<const HloInstruction*, const Literal*>&
          substitutions);

  StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
                                                const Literal& lhs,
                                                const Literal& rhs);

  StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
                                               const Literal& operand);

  StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
                                  const PrecisionConfig& precision_config,
                                  const Literal& lhs, const Literal& rhs);

  void set_dynamic_dimension_inference(
      DynamicDimensionInference* dynamic_dimension_inference) {
    dynamic_dimension_inference_ = dynamic_dimension_inference;
  }

  DynamicDimensionInference* dynamic_dimension_inference() {
    return dynamic_dimension_inference_;
  }

  // Enables the fast path for certain operations like dot or convolution.
  void set_use_fast_path(bool value) { use_fast_path_ = value; }

  // Handler for evaluating a custom-call op: the evaluated operand literals
  // are provided in |operands|, and the handler must return a literal of the
  // custom-call's result shape.
  using CustomCallHandler = std::function<StatusOr<Literal>(
      HloInstruction* custom_call, absl::Span<const Literal*> operands)>;

  // Sets a handler that is called during evaluation for custom-call ops.
  // If no handler is defined, the default error behavior will occur. The
  // handler will be provided evaluated literals for all operands and is
  // expected to return an output literal of the appropriate shape.
  void set_custom_call_handler(CustomCallHandler handler) {
    custom_call_handler_ = std::move(handler);
  }
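
  // Registration sketch (illustrative only; the lambda body is a placeholder
  // that assumes the custom-call's result shape matches operand 0's shape):
  //
  //   evaluator.set_custom_call_handler(
  //       [](HloInstruction* custom_call,
  //          absl::Span<const Literal*> operands) -> StatusOr<Literal> {
  //         // For illustration, simply echo the first operand.
  //         return operands[0]->Clone();
  //       });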

  // Returns the result of a matrix multiply `lhs x rhs`.
  static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
      const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
  static std::unique_ptr<Array2D<float>> MatmulArray2D(
      const Array2D<float>& lhs, const Array2D<float>& rhs);
  static std::unique_ptr<Array2D<double>> MatmulArray2D(
      const Array2D<double>& lhs, const Array2D<double>& rhs);
  static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D(
      const Array2D<std::complex<float>>& lhs,
      const Array2D<std::complex<float>>& rhs);
  static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D(
      const Array2D<std::complex<double>>& lhs,
      const Array2D<std::complex<double>>& rhs);
  static std::unique_ptr<Array2D<int32>> MatmulArray2D(
      const Array2D<int32>& lhs, const Array2D<int32>& rhs);
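
  // Usage sketch (illustrative only; assumes Array2D's nested
  // initializer-list constructor):
  //
  //   Array2D<float> lhs({{1.f, 2.f}, {3.f, 4.f}});
  //   Array2D<float> rhs({{5.f, 6.f}, {7.f, 8.f}});
  //   std::unique_ptr<Array2D<float>> product =
  //       HloEvaluator::MatmulArray2D(lhs, rhs);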

 protected:
  // Make HloEvaluatorTypedVisitor a friend because it is logically part of
  // this class.
  //
  // A straightforward implementation would be to make it a nested class
  // declared and defined in hlo_evaluator.cc. Instead,
  // HloEvaluatorTypedVisitor lives as a separate class with its own header
  // because its template gets instantiated many times and we want to use
  // extern templates to shard out the compilation of those instantiations
  // across multiple cc files.
  template <typename ReturnT, typename ElementwiseT>
  friend class HloEvaluatorTypedVisitor;

  // Wraps around instruction handling to infer types before dispatching to
  // the corresponding typed visitor.
  Status DefaultAction(HloInstruction* hlo) override {
    return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
  }

  Status Preprocess(HloInstruction* hlo) override;

  Status Postprocess(HloInstruction* hlo) override;

  // Operations that are type-agnostic or always return a specific type, such
  // as HandleIsFinite, where a boolean is always returned.
  //
  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override;

  Status HandleSetDimensionSize(HloInstruction* set_dimension_size) override;

  Status HandleParameter(HloInstruction* parameter) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleIsFinite(HloInstruction* is_finite) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleTuple(HloInstruction* tuple) override;

  Status HandleFft(HloInstruction* fft) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleCopyStart(HloInstruction* copy_start) override;

  Status HandleCopyDone(HloInstruction* copy_done) override;

  Status HandleConditional(HloInstruction* conditional) override;

  Status HandleCall(HloInstruction* call) override;

  Status HandleFusion(HloInstruction* fusion) override;

  Status HandleWhile(HloInstruction* while_hlo) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleTupleSelect(HloInstruction* tuple_select) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleAfterAll(HloInstruction* after_all) override;

  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReduce(HloInstruction* reduce) override;

  Status HandleReduceWindow(HloInstruction* hlo) override;

  Status HandleCustomCall(HloInstruction* custom_call) override;

  // Unsupported HLOs. Note that some of them (such as BatchNorm*) are
  // typically expanded in a semantics-preserving way into other HLOs by an
  // expansion pass in the HLO optimization pipeline during compilation, and
  // the resulting HLOs can then be handled by the evaluator.
  Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
    return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
  }
  Status HandleBatchNormInference(
      HloInstruction* batch_norm_inference) override {
    return Unimplemented(
        "BatchNormInference HLO is unsupported by the evaluator.");
  }
  Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
    return Unimplemented(
        "BatchNormTraining HLO is unsupported by the evaluator.");
  }
  Status HandleInfeed(HloInstruction* infeed) override {
    return Unimplemented("Infeed HLO is unsupported by the evaluator.");
  }
  Status HandleOutfeed(HloInstruction* outfeed) override {
    return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
  }

  // Returns the already-evaluated literal result for the instruction.
  //
  // A Constant instruction is considered evaluated and its literal will be
  // returned directly without looking up the cache.
  //
  // Similarly, a Parameter instruction is considered evaluated and its literal
  // is looked up in arg_literals_.
  //
  // Crashes with a log message if the given instruction has not been evaluated
  // previously.
  const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
    if (hlo->IsConstant()) {
      return hlo->literal();
    }
    if (hlo->opcode() == HloOpcode::kParameter) {
      return *arg_literals_.at(hlo->parameter_number());
    }
    auto it = evaluated_.find(hlo);
    CHECK(it != evaluated_.end())
        << "could not find evaluated value for: " << hlo->ToString();
    return it->second;
  }

  // Tracks the HLO instruction and its evaluated literal result.
  //
  // Parameters and constants aren't stored here; see the implementation of
  // GetEvaluatedLiteralFor.
  //
  // TODO(b/35950897): have better memory management here to free instructions
  // that are no longer a parent for any other subsequent instruction in
  // post-ordering.
  //
  // Must be cleared for each evaluation.
  //
  // Storing Literals in place requires the container to have pointer
  // stability, so we cannot use flat_hash_map here.
  absl::node_hash_map<const HloInstruction*, Literal> evaluated_;

  // Whether to use the fast path (backed by Eigen) in the evaluator.
  bool use_fast_path_ = false;

 private:
  template <typename ReturnT, typename NativeT>
  static StatusOr<Literal> ElementWiseUnaryOpImpl(
      HloInstruction* instruction,
      const std::function<ReturnT(NativeT)>& unary_op,
      const Literal& operand_literal) {
    const auto shape = instruction->shape();
    const auto* operand = instruction->operand(0);
    TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));

    Literal result(shape);
    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
          return unary_op(operand_literal.Get<NativeT>(multi_index));
        }));
    return std::move(result);
  }

  // Map from a primitive type to its associated (templated) DfsHloVisitor.
  std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];

  // Caches pointers to input literals, assuming they are in post-order.
  // Literals are not owned by this class, and they must outlive the lifetime
  // of each invocation of the Evaluate* methods.
  // Must be cleared for each evaluation.
  std::vector<const Literal*> arg_literals_;

  // Max loop iterations to execute per while-loop; a negative value means no
  // maximum.
  int64 max_loop_iterations_ = 0;

  // Module-level seed handle.
  uint64 seed_ = 0;
  // RNG engine.
  std::minstd_rand0 engine_;

  // DynamicDimensionInference is used to evaluate GetDimensionSize, which
  // returns the dynamic dimension size of its operand.
  DynamicDimensionInference* dynamic_dimension_inference_ = nullptr;

  // Optional handler for custom-call ops.
  CustomCallHandler custom_call_handler_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
};

std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs,
                                              const Array2D<float>& rhs);
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_