• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
18 
19 #define _USE_MATH_DEFINES
20 
#include <complex>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <random>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
41 
42 namespace xla {
43 
// Represents a parsed static while loop. We normalize the loop representation
// so that it starts from the induction_var_init_value and increments by
// step_size until it exceeds or goes below loop_bound.
struct ParsedStaticWhileLoop {
  // The number of iterations to be executed.
  int64_t trip_count = -1;
  // The tuple index of the induction variable in the while argument tuple.
  int64_t induction_var_index = -1;
  // The induction variable's initial value.
  int64_t induction_var_init_value = -1;
  // The induction variable is incremented by this number (could be negative)
  // in each iteration.
  int64_t step_size = -1;
  // The bound the induction variable is compared against to decide whether
  // the loop continues.
  int64_t loop_bound = -1;
};

// Indicates whether a parsed while loop is static or dynamic. If the loop is
// static, it contains a value for StaticLoopInfo; otherwise the loop is
// dynamic. We consider a loop dynamic if its induction variable's initial
// value or the loop bound's value depends on the while's parent computation's
// parameter.
struct ParsedWhileLoop {
  std::optional<ParsedStaticWhileLoop> static_while_loop;
  // True iff no static loop structure could be extracted.
  bool is_dynamic() const { return !static_while_loop.has_value(); }
};
// Canonical value representing a dynamic (unanalyzable) while loop.
// `inline` ensures a single definition shared across all translation units
// that include this header (C++17).
inline constexpr ParsedWhileLoop kParsedDynamicWhileLoop = ParsedWhileLoop();
70 
71 // Tries to parse a while loop using a set of predefined patterns.
72 // Returns the parsing result.
73 std::optional<ParsedWhileLoop> PatternMatchParseWhileLoop(
74     HloInstruction* while_op);
75 
76 // Responsible for evaluating HLO and obtain literal as the evaluation results.
77 //
78 // This class is not thread-safe.
79 class HloEvaluator : public DfsHloVisitorWithDefault {
80  public:
81   // Only evaluate up to max_loop_iterations per while-loop execution if
82   // specified.
83   explicit HloEvaluator(int64_t max_loop_iterations = -1);
84 
85   // Called by the evaluator to create an embedded evaluator to execute a
86   // sub-region of control flow. Subclasses should override this to return an
87   // instance of the subclass instead.
CreateEmbedded(int64_t max_loop_iterations)88   virtual std::unique_ptr<HloEvaluator> CreateEmbedded(
89       int64_t max_loop_iterations) {
90     return std::make_unique<HloEvaluator>(max_loop_iterations);
91   }
92 
93   // Evaluates an HLO module and an array of pointers to literals.  Returns the
94   // evaluated result as a literal if successful.
95   //
96   // Precondition: The indices of arg_literals correspond to the parameter
97   // numbers of the HLO parameters in the computation. See comment below for an
98   // example.
99   //
100   // (Dummy template arg is to reduce the overloading priority of one overload
101   // so that Evaluate(module, {}) resolves unambiguously.)
Evaluate(const HloModule & module,absl::Span<const Literal * const> arg_literals)102   StatusOr<Literal> Evaluate(const HloModule& module,
103                              absl::Span<const Literal* const> arg_literals) {
104     return Evaluate(*module.entry_computation(), arg_literals);
105   }
106   template <typename Dummy = void>
Evaluate(const HloModule & module,absl::Span<const Literal> arg_literals)107   StatusOr<Literal> Evaluate(const HloModule& module,
108                              absl::Span<const Literal> arg_literals) {
109     return Evaluate(*module.entry_computation(), arg_literals);
110   }
111 
112   // Evaluates an HLO computation and an array of pointers to literals.
113   // Returns the evaluated result as a literal if successful.
114   // Precondition: The indices of arg_literals correspond to the parameter
115   // numbers of the HLO parameters in the computation. For e.g., consider the
116   // following graph:
117   //
118   //                *
119   //            /       \
120   //            +     Parameter1
121   //        /      \
122   //       /        \
123   //    Parameter0  Constant
124   //
125   // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
126   // 1 in this computation. The input literals array will then have its first
127   // literal map to Parameter0 and the second map to Parameter1.
128   //
129   // (Dummy template arg is to reduce the overloading priority of one overload
130   // so that Evaluate(module, {}) resolves unambiguously.)
131   StatusOr<Literal> Evaluate(const HloComputation& computation,
132                              absl::Span<const Literal* const> arg_literals);
133   template <typename Dummy = void>
Evaluate(const HloComputation & computation,absl::Span<const Literal> arg_literals)134   StatusOr<Literal> Evaluate(const HloComputation& computation,
135                              absl::Span<const Literal> arg_literals) {
136     std::vector<const Literal*> arg_literal_ptrs;
137     for (const auto& l : arg_literals) {
138       arg_literal_ptrs.push_back(&l);
139     }
140     return Evaluate(computation, arg_literal_ptrs);
141   }
142 
143   // Gets the value of running a single HLO instruction.
144   //
145   // This function may recursively evaluate the dependency of this instruction
146   // within its parent computation until it encounters something that cannot be
147   // evaluated, such as an Infeed or a Parameter instruction.
148   // It makes best effort to partially evaluate a dependency if possible.
149   StatusOr<Literal> Evaluate(
150       HloInstruction* instruction,
151       bool recursively_evaluate_nonconstant_operands = false);
152 
153   // Same as Evaluate, except returning false on error and accepts an output
154   // pointer.
155   bool TryEvaluate(HloInstruction* instruction, Literal* result,
156                    bool recursively_evaluate_nonconstant_operands = false);
157 
158   // Evaluates a single HLO instruction, substituting the given literals for
159   // some of the instruction's operands.
160   //
161   // For example, given instruction = op(A, B, C) and the map
162   // {A = x, C = y}, this evaluates op(x, B, y).
163   StatusOr<Literal> EvaluateWithSubstitutions(
164       const HloInstruction* instruction,
165       const absl::flat_hash_map<const HloInstruction*, const Literal*>&
166           substitutions);
167 
168   StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
169                                                 const Literal& lhs,
170                                                 const Literal& rhs);
171 
172   StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
173                                                const Literal& operand);
174 
175   StatusOr<Literal> EvaluateElementwiseTernaryOp(HloOpcode opcode,
176                                                  const Literal& lhs,
177                                                  const Literal& rhs,
178                                                  const Literal& ehs);
179 
180   StatusOr<Literal> EvaluateElementwiseCompareOp(ComparisonDirection direction,
181                                                  const Literal& lhs,
182                                                  const Literal& rhs);
183 
184   StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
185                                   const PrecisionConfig& precision_config,
186                                   const Literal& lhs, const Literal& rhs);
187 
set_dynamic_dimension_inference(DynamicDimensionInference * dynamic_dimension_inference)188   void set_dynamic_dimension_inference(
189       DynamicDimensionInference* dynamic_dimension_inference) {
190     dynamic_dimension_inference_ = dynamic_dimension_inference;
191   }
192 
dynamic_dimension_inference()193   DynamicDimensionInference* dynamic_dimension_inference() {
194     return dynamic_dimension_inference_;
195   }
196 
197   // Enable the fast path for certain operations like dot or convolution.
set_use_fast_path(bool value)198   void set_use_fast_path(bool value) { use_fast_path_ = value; }
199 
200   // Handles evaluation of a custom-call op.
201   // Operand literals are provided in |operands| and implementations must
202   // populate |output| before returning.
203   using CustomCallHandler = std::function<StatusOr<Literal>(
204       HloInstruction* custom_call, absl::Span<const Literal*> operands)>;
205 
206   // Sets a handler that is called during evaluation for custom-call ops.
207   // If no handler is defined the default error behavior will occur. The handler
208   // will be provided evaluated literals for all operands and is expected to
209   // return an output literal of the appropriate shape.
set_custom_call_handler(std::function<StatusOr<Literal> (HloInstruction * custom_call,absl::Span<const Literal * > operands)> handler)210   void set_custom_call_handler(
211       std::function<StatusOr<Literal>(HloInstruction* custom_call,
212                                       absl::Span<const Literal*> operands)>
213           handler) {
214     custom_call_handler_ = std::move(handler);
215   }
216 
217   // Returns the result of a matrix multiply `lhs x rhs`.
218   static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
219       const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
220   static std::unique_ptr<Array2D<float>> MatmulArray2D(
221       const Array2D<float>& lhs, const Array2D<float>& rhs);
222   static std::unique_ptr<Array2D<double>> MatmulArray2D(
223       const Array2D<double>& lhs, const Array2D<double>& rhs);
224   static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D(
225       const Array2D<std::complex<float>>& lhs,
226       const Array2D<std::complex<float>>& rhs);
227   static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D(
228       const Array2D<std::complex<double>>& lhs,
229       const Array2D<std::complex<double>>& rhs);
230   static std::unique_ptr<Array2D<int32_t>> MatmulArray2D(
231       const Array2D<int32_t>& lhs, const Array2D<int32_t>& rhs);
232 
233  protected:
234   // Evaluates the given instruction, and stores the evaluation result in the
235   // evaluated_ map.
236   // When a non-empty shape_index is given, the instruction may be partially
237   // evaluated at the given shape_index and the rest of the result could be
238   // marked as undetermined unless it has been previously evaluated using
239   // EvaluateInternal. Such partial evaluation reduces the computation and
240   // memory overhead in cases where we need only one tuple element by avoiding
241   // the evaluation of a full tuple.
242   Status EvaluateInternal(
243       HloInstruction* instruction, const ShapeIndex& shape_index = {},
244       bool recursively_evaluate_nonconstant_operands = false);
245   // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
246   // class.
247   //
248   // A straightforward implementation would be to make it a nested class
249   // declared and defined in hlo_evaluator.cc.  Instead HloEvaluatorTypedVisitor
250   // lives as a separate class with its own header because its template gets
251   // instantiated many times and we want to use extern templates to shard out
252   // the compilation of those instantiations across multiple cc files.
253   template <typename ReturnT, typename ElementwiseT>
254   friend class HloEvaluatorTypedVisitor;
255 
256   // Wraps around instruction handling to infer types before dispatching to
257   // the corresponding typed Visitor.
DefaultAction(HloInstruction * hlo)258   Status DefaultAction(HloInstruction* hlo) override {
259     return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
260   }
261 
262   Status Preprocess(HloInstruction* hlo) override;
263 
264   Status Postprocess(HloInstruction* hlo) override;
265 
266   // Operations that are type-agnostic or always return a specific type, such as
267   // HandleIsFinite where boolean is always returned.
268   //
269   Status HandleBitcast(HloInstruction* bitcast) override;
270 
271   Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override;
272 
273   Status HandleSetDimensionSize(HloInstruction* set_dimension_size) override;
274 
275   Status HandleParameter(HloInstruction* parameter) override;
276 
277   Status HandleInfeed(HloInstruction* infeed) override;
278 
279   Status HandleConstant(HloInstruction* constant) override;
280 
281   Status HandleConcatenate(HloInstruction* concatenate) override;
282 
283   Status HandleReshape(HloInstruction* reshape) override;
284 
285   Status HandleTranspose(HloInstruction* transpose) override;
286 
287   Status HandleIsFinite(HloInstruction* is_finite) override;
288 
289   Status HandleCompare(HloInstruction* compare) override;
290 
291   Status HandleTuple(HloInstruction* tuple) override;
292 
293   Status HandleFft(HloInstruction* fft) override;
294 
295   Status HandleGather(HloInstruction* gather) override;
296 
297   Status HandleScatter(HloInstruction* hlo) override;
298 
299   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
300 
301   Status HandleAsyncStart(HloInstruction* async_start) override;
302 
303   Status HandleAsyncUpdate(HloInstruction* async_update) override;
304 
305   Status HandleAsyncDone(HloInstruction* async_done) override;
306 
307   Status HandleCopy(HloInstruction* copy) override;
308 
309   Status HandleCopyStart(HloInstruction* copy_start) override;
310 
311   Status HandleCopyDone(HloInstruction* copy_done) override;
312 
313   Status HandleConditional(HloInstruction* conditional) override;
314 
315   Status HandleCall(HloInstruction* call) override;
316 
317   Status HandleFusion(HloInstruction* fusion) override;
318 
319   Status HandleWhile(HloInstruction* while_hlo) override;
320 
321   Status HandleSelect(HloInstruction* select) override;
322 
323   Status HandleBroadcast(HloInstruction* broadcast) override;
324 
325   Status HandleAfterAll(HloInstruction* after_all) override;
326 
327   Status HandleAddDependency(HloInstruction* add_dependency) override;
328 
329   Status HandleSort(HloInstruction* sort) override;
330 
331   Status HandleReal(HloInstruction* real) override;
332 
333   Status HandleImag(HloInstruction* imag) override;
334 
335   Status HandleComplex(HloInstruction* complex) override;
336 
337   Status HandleReduce(HloInstruction* reduce) override;
338 
339   Status HandleReduceWindow(HloInstruction* hlo) override;
340 
341   Status HandleCustomCall(HloInstruction* custom_call) override;
342 
343   // Unsupported HLOs, note some of them (such as BatchNorm*) are typically
344   // expanded in a semantic-preserving way into other HLOs by adding expansion
345   // HLO pass to the HLO optimization pass during compilation, which can then be
346   // handled by the evaluator.
HandleBatchNormGrad(HloInstruction * batch_norm_grad)347   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
348     return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
349   }
HandleBatchNormInference(HloInstruction * batch_norm_inference)350   Status HandleBatchNormInference(
351       HloInstruction* batch_norm_inference) override {
352     return Unimplemented(
353         "BatchNormInference HLO is unsupported by the evaluator.");
354   }
HandleBatchNormTraining(HloInstruction * batch_norm_training)355   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
356     return Unimplemented(
357         "BatchNormTraining HLO is unsupported by the evaluator.");
358   }
HandleOutfeed(HloInstruction * outfeed)359   Status HandleOutfeed(HloInstruction* outfeed) override {
360     return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
361   }
362 
363   // Returns the already-evaluated literal result for the instruction.
364   //
365   // A Constant instruction is considered evaluated and its literal will be
366   // returned directly without looking up the cache.
367   //
368   // Similarly, a Parameter instruction is considered evaluated and its literal
369   // is looked up in arg_literals.
370   //
371   // Crash with log if the given instruction has not been evaluated previously.
GetEvaluatedLiteralFor(const HloInstruction * hlo)372   const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
373     if (hlo->IsConstant()) {
374       return hlo->literal();
375     }
376     if (hlo->opcode() == HloOpcode::kParameter && !arg_literals_.empty()) {
377       return *arg_literals_.at(hlo->parameter_number());
378     }
379 
380     auto it = evaluated_.find(hlo);
381     CHECK(it != evaluated_.end())
382         << "could not find evaluated value for: " << hlo->ToString();
383     return it->second;
384   }
385 
386   // Returns true if the given hlo has been evaluated and cached.
387   bool IsAlreadyEvaluated(const HloInstruction* hlo,
388                           const ShapeIndex& shape_index = {}) {
389     if (hlo->IsConstant()) {
390       return true;
391     }
392     if (hlo->opcode() == HloOpcode::kParameter && !arg_literals_.empty()) {
393       return true;
394     }
395     auto it = evaluated_.find(hlo);
396     if (it == evaluated_.end()) {
397       return false;
398     }
399     // We may evaluate some elements of a tuple-shaped instruction and mark
400     // the other elements as undetermined. This way we avoid the computation
401     // and memory overhead of evaluating a large tuple when only some elements
402     // are needed. By marking the other elements undetermined, we allow the
403     // evaluator to update the cached tuple literal when more elements are
404     // evaluated.
405     return it->second.IsDetermined(shape_index);
406   }
407 
408   // Tracks the HLO instruction and its evaluated literal result.
409   //
410   // Parameters and constants aren't stored here, see implementation of
411   // GetEvaluatedLiteralFor.
412   //
413   // TODO(b/35950897): have better memory management here to free instructions
414   // that are no longer a parent for any other subsequent instruction in
415   // post-ordering.
416   //
417   // Must be cleared for each evaluation.
418   //
419   // Storing Literal in place requires the container to have pointer stability
420   // so we cannot use flat_hash_map any more.
421   absl::node_hash_map<const HloInstruction*, Literal> evaluated_;
422   // Set by EvaluateInternal and opportunitiscally used by the HandleXXX
423   // functions. When non-empty, the HandleXXX function may evaluate the
424   // instruction at only the given shape index.
425   ShapeIndex visitor_shape_index_;
426   bool enable_partial_evaluation_ = false;
427 
428   // Use fast path that uses eigen in the evaluator.
429   bool use_fast_path_ = false;
430 
431  private:
432   template <typename ReturnT, typename NativeT>
ElementWiseUnaryOpImpl(HloInstruction * instruction,const std::function<ReturnT (NativeT)> & unary_op,const Literal & operand_literal)433   static StatusOr<Literal> ElementWiseUnaryOpImpl(
434       HloInstruction* instruction,
435       const std::function<ReturnT(NativeT)>& unary_op,
436       const Literal& operand_literal) {
437     const auto shape = instruction->shape();
438     const auto* operand = instruction->operand(0);
439     TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));
440 
441     Literal result(shape);
442     TF_RETURN_IF_ERROR(
443         result.Populate<ReturnT>([&](absl::Span<const int64_t> multi_index) {
444           return unary_op(operand_literal.Get<NativeT>(multi_index));
445         }));
446     return std::move(result);
447   }
448 
449   // Map from a primitive type to its associated (templated) DfsHloVisitor.
450   std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];
451 
452   // Caches pointers to input literals, assuming they are in post-order.
453   // Literals are not owned by this class, and they must outlive the lifetime of
454   // each invocation to the Evaluate* method.
455   // Must be cleared for each evaluation.
456   std::vector<const Literal*> arg_literals_;
457 
458   // Max loop iterations to execute with no maximum if negative.
459   int64_t max_loop_iterations_ = 0;
460 
461   // Module-level seed handle.
462   uint64_t seed_ = 0;
463   // RNG engine.
464   std::minstd_rand0 engine_;
465 
466   // DynamicDimensionInference is used to evaluate GetDimensionSize, which
467   // returns the dynamic dimension size of its operand.
468   DynamicDimensionInference* dynamic_dimension_inference_ = nullptr;
469 
470   // Optional handler for custom_call ops.
471   std::function<StatusOr<Literal>(HloInstruction* custom_call,
472                                   absl::Span<const Literal*> operands)>
473       custom_call_handler_;
474 
475   HloEvaluator(const HloEvaluator&) = delete;
476   HloEvaluator& operator=(const HloEvaluator&) = delete;
477 };
478 
479 std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs,
480                                               const Array2D<float>& rhs);
481 }  // namespace xla
482 
483 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
484