1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "absl/container/node_hash_map.h" 23 #include "absl/memory/memory.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/array2d.h" 26 #include "tensorflow/compiler/xla/literal.h" 27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 28 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" 29 #include "tensorflow/compiler/xla/service/hlo_computation.h" 30 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 31 #include "tensorflow/compiler/xla/service/hlo_module.h" 32 #include "tensorflow/compiler/xla/service/shape_inference.h" 33 #include "tensorflow/compiler/xla/statusor.h" 34 #include "tensorflow/compiler/xla/util.h" 35 #include "tensorflow/compiler/xla/xla_data.pb.h" 36 #include "tensorflow/core/platform/macros.h" 37 38 namespace xla { 39 40 // Responsible for evaluating HLO and obtain literal as the evaluation results. 41 // 42 // This class is not thread-safe. 43 class HloEvaluator : public DfsHloVisitorWithDefault { 44 public: 45 // Only evaluate up to max_loop_iterations per while-loop execution if 46 // specified. 47 explicit HloEvaluator(int64 max_loop_iterations = -1); 48 49 // Evaluates an HLO module and an array of pointers to literals. Returns the 50 // evaluated result as a literal if successful. 51 // 52 // Precondition: The indices of arg_literals correspond to the parameter 53 // numbers of the HLO parameters in the computation. See comment below for an 54 // example. 55 // 56 // (Dummy template arg is to reduce the overloading priority of one overload 57 // so that Evaluate(module, {}) resolves unambiguously.) Evaluate(const HloModule & module,absl::Span<const Literal * const> arg_literals)58 StatusOr<Literal> Evaluate(const HloModule& module, 59 absl::Span<const Literal* const> arg_literals) { 60 return Evaluate(*module.entry_computation(), arg_literals); 61 } 62 template <typename Dummy = void> Evaluate(const HloModule & module,absl::Span<const Literal> arg_literals)63 StatusOr<Literal> Evaluate(const HloModule& module, 64 absl::Span<const Literal> arg_literals) { 65 return Evaluate(*module.entry_computation(), arg_literals); 66 } 67 68 // Evaluates an HLO computation and an array of pointers to literals. 69 // Returns the evaluated result as a literal if successful. 70 // Precondition: The indices of arg_literals correspond to the parameter 71 // numbers of the HLO parameters in the computation. For e.g., consider the 72 // following graph: 73 // 74 // * 75 // / \ 76 // + Parameter1 77 // / \ 78 // / \ 79 // Parameter0 Constant 80 // 81 // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number 82 // 1 in this computation. The input literals array will then have its first 83 // literal map to Parameter0 and the second map to Parameter1. 84 // 85 // (Dummy template arg is to reduce the overloading priority of one overload 86 // so that Evaluate(module, {}) resolves unambiguously.) 87 StatusOr<Literal> Evaluate(const HloComputation& computation, 88 absl::Span<const Literal* const> arg_literals); 89 template <typename Dummy = void> Evaluate(const HloComputation & computation,absl::Span<const Literal> arg_literals)90 StatusOr<Literal> Evaluate(const HloComputation& computation, 91 absl::Span<const Literal> arg_literals) { 92 std::vector<const Literal*> arg_literal_ptrs; 93 for (const auto& l : arg_literals) { 94 arg_literal_ptrs.push_back(&l); 95 } 96 return Evaluate(computation, arg_literal_ptrs); 97 } 98 99 // Gets the value of running a single HLO instruction. 100 // 101 // All of the operands to this instruction must be constants. 102 StatusOr<Literal> Evaluate(HloInstruction* instruction); 103 104 // Same as Evaluate, except returning false on error and accepts an output 105 // pointer. 106 bool TryEvaluate(HloInstruction* instruction, Literal* result); 107 108 // Evaluates a single HLO instruction, substituting the given literals for 109 // some of the instruction's operands. 110 // 111 // For example, given instruction = op(A, B, C) and the map 112 // {A = x, C = y}, this evaluates op(x, B, y). 113 StatusOr<Literal> EvaluateWithSubstitutions( 114 const HloInstruction* instruction, 115 const std::unordered_map<const HloInstruction*, const Literal*>& 116 substitutions); 117 118 StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode, 119 const Literal& lhs, 120 const Literal& rhs); 121 122 StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode, 123 const Literal& operand); 124 125 StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers, 126 const PrecisionConfig& precision_config, 127 const Literal& lhs, const Literal& rhs); 128 set_dynamic_dimension_inference(DynamicDimensionInference * dynamic_dimension_inference)129 void set_dynamic_dimension_inference( 130 DynamicDimensionInference* dynamic_dimension_inference) { 131 dynamic_dimension_inference_ = dynamic_dimension_inference; 132 } 133 134 // Enable the fast path for certain operations like dot or convolution. set_use_fast_path(bool value)135 void set_use_fast_path(bool value) { use_fast_path_ = value; } 136 137 // Handles evaluation of a custom-call op. 138 // Operand literals are provided in |operands| and implementations must 139 // populate |output| before returning. 140 using CustomCallHandler = std::function<StatusOr<Literal>( 141 HloInstruction* custom_call, absl::Span<const Literal*> operands)>; 142 143 // Sets a handler that is called during evaluation for custom-call ops. 144 // If no handler is defined the default error behavior will occur. The handler 145 // will be provided evaluated literals for all operands and is expected to 146 // return an output literal of the appropriate shape. set_custom_call_handler(std::function<StatusOr<Literal> (HloInstruction * custom_call,absl::Span<const Literal * > operands)> handler)147 void set_custom_call_handler( 148 std::function<StatusOr<Literal>(HloInstruction* custom_call, 149 absl::Span<const Literal*> operands)> 150 handler) { 151 custom_call_handler_ = std::move(handler); 152 } 153 154 // Returns the result of a matrix multiply `lhs x rhs`. 155 static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D( 156 const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs); 157 static std::unique_ptr<Array2D<float>> MatmulArray2D( 158 const Array2D<float>& lhs, const Array2D<float>& rhs); 159 static std::unique_ptr<Array2D<double>> MatmulArray2D( 160 const Array2D<double>& lhs, const Array2D<double>& rhs); 161 162 protected: 163 // Make HloEvaluatorTypedVisitor a friend because it is logically part of this 164 // class. 165 // 166 // A straightforward implementation would be to make it a nested class 167 // declared and defined in hlo_evaluator.cc. Instead HloEvaluatorTypedVisitor 168 // lives as a separate class with its own header because its template gets 169 // instantiated many times and we want to use extern templates to shard out 170 // the compilation of those instantiations across multiple cc files. 171 template <typename ReturnT, typename ElementwiseT> 172 friend class HloEvaluatorTypedVisitor; 173 174 // Wraps around instruction handling to infer types before dispatching to 175 // the corresponding typed Visitor. DefaultAction(HloInstruction * hlo)176 Status DefaultAction(HloInstruction* hlo) override { 177 return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get()); 178 } 179 180 Status Preprocess(HloInstruction* hlo) override; 181 182 Status Postprocess(HloInstruction* hlo) override; 183 184 // Operations that are type-agnostic or always return a specific type, such as 185 // HandleIsFinite where boolean is always returned. 186 // 187 Status HandleBitcast(HloInstruction* bitcast) override; 188 189 Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override; 190 191 Status HandleParameter(HloInstruction* parameter) override; 192 193 Status HandleConstant(HloInstruction* constant) override; 194 195 Status HandleConcatenate(HloInstruction* concatenate) override; 196 197 Status HandleReshape(HloInstruction* reshape) override; 198 199 Status HandleTranspose(HloInstruction* transpose) override; 200 201 Status HandleIsFinite(HloInstruction* is_finite) override; 202 203 Status HandleCompare(HloInstruction* compare) override; 204 205 Status HandleTuple(HloInstruction* tuple) override; 206 207 Status HandleGather(HloInstruction* gather) override; 208 209 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 210 211 Status HandleCopy(HloInstruction* copy) override; 212 213 Status HandleConditional(HloInstruction* conditional) override; 214 215 Status HandleCall(HloInstruction* call) override; 216 217 Status HandleFusion(HloInstruction* fusion) override; 218 219 Status HandleWhile(HloInstruction* while_hlo) override; 220 221 Status HandleSelect(HloInstruction* select) override; 222 223 Status HandleTupleSelect(HloInstruction* tuple_select) override; 224 225 Status HandleBroadcast(HloInstruction* broadcast) override; 226 227 Status HandleAfterAll(HloInstruction* after_all) override; 228 229 Status HandleAddDependency(HloInstruction* add_dependency) override; 230 231 Status HandleSort(HloInstruction* sort) override; 232 233 Status HandleReal(HloInstruction* real) override; 234 235 Status HandleImag(HloInstruction* imag) override; 236 237 Status HandleComplex(HloInstruction* complex) override; 238 239 Status HandleReduce(HloInstruction* reduce) override; 240 241 Status HandleCustomCall(HloInstruction* custom_call) override; 242 243 // Unsupported HLOs, note some of them (such as BatchNorm*) are typically 244 // expanded in a semantic-preserving way into other HLOs by adding exanpsion 245 // HLO pass to the HLO optimization pass during compilation, which can then be 246 // handled by the evaluator. HandleBatchNormGrad(HloInstruction * batch_norm_grad)247 Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { 248 return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator."); 249 }; HandleBatchNormInference(HloInstruction * batch_norm_inference)250 Status HandleBatchNormInference( 251 HloInstruction* batch_norm_inference) override { 252 return Unimplemented( 253 "BatchNormInference HLO is unsupported by the evaluator."); 254 }; HandleBatchNormTraining(HloInstruction * batch_norm_training)255 Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { 256 return Unimplemented( 257 "BatchNormTraining HLO is unsupported by the evaluator."); 258 }; HandleInfeed(HloInstruction * infeed)259 Status HandleInfeed(HloInstruction* infeed) override { 260 return Unimplemented("Infeed HLO is unsupported by the evaluator."); 261 }; HandleOutfeed(HloInstruction * outfeed)262 Status HandleOutfeed(HloInstruction* outfeed) override { 263 return Unimplemented("Outfeed HLO is unsupported by the evaluator."); 264 }; 265 266 // Returns the already-evaluated literal result for the instruction. 267 // 268 // A Constant instruction is considered evaluated and its literal will be 269 // returned directly without looking up the cache. 270 // 271 // Similarly, a Parameter instruction is considered evaluated and its literal 272 // is looked up in arg_literals. 273 // 274 // Crash with log if the given instruction has not been evaluated previously. GetEvaluatedLiteralFor(const HloInstruction * hlo)275 const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { 276 if (hlo->IsConstant()) { 277 return hlo->literal(); 278 } 279 if (hlo->opcode() == HloOpcode::kParameter) { 280 return *arg_literals_.at(hlo->parameter_number()); 281 } 282 auto it = evaluated_.find(hlo); 283 CHECK(it != evaluated_.end()) 284 << "could not find evaluated value for: " << hlo->ToString(); 285 return it->second; 286 } 287 288 // Tracks the HLO instruction and its evaluated literal result. 289 // 290 // Parameters and constants aren't stored here, see implementation of 291 // GetEvaluatedLiteralFor. 292 // 293 // TODO(b/35950897): have better memory management here to free instructions 294 // that are no longer a parent for any other subsequent instruction in 295 // post-orderring. 296 // 297 // Must be cleared for each evaluation. 298 // 299 // Storing Literal in place requires the container to have pointer stability 300 // so we cannot use flat_hash_map any more. 301 absl::node_hash_map<const HloInstruction*, Literal> evaluated_; 302 303 // Use fast path that uses eigen in the evaluator. 304 bool use_fast_path_ = false; 305 306 private: 307 template <typename ReturnT, typename NativeT> ElementWiseUnaryOpImpl(HloInstruction * instruction,const std::function<ReturnT (NativeT)> & unary_op,const Literal & operand_literal)308 static StatusOr<Literal> ElementWiseUnaryOpImpl( 309 HloInstruction* instruction, 310 const std::function<ReturnT(NativeT)>& unary_op, 311 const Literal& operand_literal) { 312 const auto shape = instruction->shape(); 313 const auto* operand = instruction->operand(0); 314 TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); 315 316 Literal result(shape); 317 TF_RETURN_IF_ERROR( 318 result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { 319 return unary_op(operand_literal.Get<NativeT>(multi_index)); 320 })); 321 return std::move(result); 322 } 323 324 // Map from a primitive type to its associated (templated) DfsHloVisitor. 325 std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE]; 326 327 // Caches pointers to input literals, assuming they are in post-order. 328 // Literals are not owned by this class, and they must outlive the lifetime of 329 // each invocation to the Evaluate* method. 330 // Must be cleared for each evaluation. 331 std::vector<const Literal*> arg_literals_; 332 333 // Max loop iterations to execute with no maximum if negative. 334 int64 max_loop_iterations_ = 0; 335 336 // Module-level seed handle. 337 uint64 seed_ = 0; 338 // RNG engine. 339 std::minstd_rand0 engine_; 340 341 // DynamicDimensionInference is used to evaluate GetDimensionSize, which 342 // returns the dynamic dimension size of its operand. 343 DynamicDimensionInference* dynamic_dimension_inference_ = nullptr; 344 345 // Optional handler for custom_call ops. 346 std::function<StatusOr<Literal>(HloInstruction* custom_call, 347 absl::Span<const Literal*> operands)> 348 custom_call_handler_; 349 350 TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); 351 }; 352 353 std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs, 354 const Array2D<float>& rhs); 355 } // namespace xla 356 357 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ 358