/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
18 
19 #define _USE_MATH_DEFINES
20 
#include <complex>
#include <cstdint>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
39 
40 namespace xla {
41 
42 // Responsible for evaluating HLO and obtain literal as the evaluation results.
43 //
44 // This class is not thread-safe.
45 class HloEvaluator : public DfsHloVisitorWithDefault {
46  public:
47   // Only evaluate up to max_loop_iterations per while-loop execution if
48   // specified.
49   explicit HloEvaluator(int64_t max_loop_iterations = -1);
50 
51   // Evaluates an HLO module and an array of pointers to literals.  Returns the
52   // evaluated result as a literal if successful.
53   //
54   // Precondition: The indices of arg_literals correspond to the parameter
55   // numbers of the HLO parameters in the computation. See comment below for an
56   // example.
57   //
58   // (Dummy template arg is to reduce the overloading priority of one overload
59   // so that Evaluate(module, {}) resolves unambiguously.)
Evaluate(const HloModule & module,absl::Span<const Literal * const> arg_literals)60   StatusOr<Literal> Evaluate(const HloModule& module,
61                              absl::Span<const Literal* const> arg_literals) {
62     return Evaluate(*module.entry_computation(), arg_literals);
63   }
64   template <typename Dummy = void>
Evaluate(const HloModule & module,absl::Span<const Literal> arg_literals)65   StatusOr<Literal> Evaluate(const HloModule& module,
66                              absl::Span<const Literal> arg_literals) {
67     return Evaluate(*module.entry_computation(), arg_literals);
68   }
69 
70   // Evaluates an HLO computation and an array of pointers to literals.
71   // Returns the evaluated result as a literal if successful.
72   // Precondition: The indices of arg_literals correspond to the parameter
73   // numbers of the HLO parameters in the computation. For e.g., consider the
74   // following graph:
75   //
76   //                *
77   //            /       \
78   //            +     Parameter1
79   //        /      \
80   //       /        \
81   //    Parameter0  Constant
82   //
83   // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
84   // 1 in this computation. The input literals array will then have its first
85   // literal map to Parameter0 and the second map to Parameter1.
86   //
87   // (Dummy template arg is to reduce the overloading priority of one overload
88   // so that Evaluate(module, {}) resolves unambiguously.)
89   StatusOr<Literal> Evaluate(const HloComputation& computation,
90                              absl::Span<const Literal* const> arg_literals);
91   template <typename Dummy = void>
Evaluate(const HloComputation & computation,absl::Span<const Literal> arg_literals)92   StatusOr<Literal> Evaluate(const HloComputation& computation,
93                              absl::Span<const Literal> arg_literals) {
94     std::vector<const Literal*> arg_literal_ptrs;
95     for (const auto& l : arg_literals) {
96       arg_literal_ptrs.push_back(&l);
97     }
98     return Evaluate(computation, arg_literal_ptrs);
99   }
100 
101   // Gets the value of running a single HLO instruction.
102   //
103   // All of the operands to this instruction must be constants.
104   StatusOr<Literal> Evaluate(HloInstruction* instruction);
105 
106   // Same as Evaluate, except returning false on error and accepts an output
107   // pointer.
108   bool TryEvaluate(HloInstruction* instruction, Literal* result);
109 
110   // Evaluates a single HLO instruction, substituting the given literals for
111   // some of the instruction's operands.
112   //
113   // For example, given instruction = op(A, B, C) and the map
114   // {A = x, C = y}, this evaluates op(x, B, y).
115   StatusOr<Literal> EvaluateWithSubstitutions(
116       const HloInstruction* instruction,
117       const std::unordered_map<const HloInstruction*, const Literal*>&
118           substitutions);
119 
120   StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
121                                                 const Literal& lhs,
122                                                 const Literal& rhs);
123 
124   StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
125                                                const Literal& operand);
126 
127   StatusOr<Literal> EvaluateElementwiseTernaryOp(HloOpcode opcode,
128                                                  const Literal& lhs,
129                                                  const Literal& rhs,
130                                                  const Literal& ehs);
131 
132   StatusOr<Literal> EvaluateElementwiseCompareOp(ComparisonDirection direction,
133                                                  const Literal& lhs,
134                                                  const Literal& rhs);
135 
136   StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
137                                   const PrecisionConfig& precision_config,
138                                   const Literal& lhs, const Literal& rhs);
139 
set_dynamic_dimension_inference(DynamicDimensionInference * dynamic_dimension_inference)140   void set_dynamic_dimension_inference(
141       DynamicDimensionInference* dynamic_dimension_inference) {
142     dynamic_dimension_inference_ = dynamic_dimension_inference;
143   }
144 
dynamic_dimension_inference()145   DynamicDimensionInference* dynamic_dimension_inference() {
146     return dynamic_dimension_inference_;
147   }
148 
149   // Enable the fast path for certain operations like dot or convolution.
set_use_fast_path(bool value)150   void set_use_fast_path(bool value) { use_fast_path_ = value; }
151 
152   // Handles evaluation of a custom-call op.
153   // Operand literals are provided in |operands| and implementations must
154   // populate |output| before returning.
155   using CustomCallHandler = std::function<StatusOr<Literal>(
156       HloInstruction* custom_call, absl::Span<const Literal*> operands)>;
157 
158   // Sets a handler that is called during evaluation for custom-call ops.
159   // If no handler is defined the default error behavior will occur. The handler
160   // will be provided evaluated literals for all operands and is expected to
161   // return an output literal of the appropriate shape.
set_custom_call_handler(std::function<StatusOr<Literal> (HloInstruction * custom_call,absl::Span<const Literal * > operands)> handler)162   void set_custom_call_handler(
163       std::function<StatusOr<Literal>(HloInstruction* custom_call,
164                                       absl::Span<const Literal*> operands)>
165           handler) {
166     custom_call_handler_ = std::move(handler);
167   }
168 
169   // Returns the result of a matrix multiply `lhs x rhs`.
170   static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
171       const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
172   static std::unique_ptr<Array2D<float>> MatmulArray2D(
173       const Array2D<float>& lhs, const Array2D<float>& rhs);
174   static std::unique_ptr<Array2D<double>> MatmulArray2D(
175       const Array2D<double>& lhs, const Array2D<double>& rhs);
176   static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D(
177       const Array2D<std::complex<float>>& lhs,
178       const Array2D<std::complex<float>>& rhs);
179   static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D(
180       const Array2D<std::complex<double>>& lhs,
181       const Array2D<std::complex<double>>& rhs);
182   static std::unique_ptr<Array2D<int32>> MatmulArray2D(
183       const Array2D<int32>& lhs, const Array2D<int32>& rhs);
184 
185  protected:
186   // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
187   // class.
188   //
189   // A straightforward implementation would be to make it a nested class
190   // declared and defined in hlo_evaluator.cc.  Instead HloEvaluatorTypedVisitor
191   // lives as a separate class with its own header because its template gets
192   // instantiated many times and we want to use extern templates to shard out
193   // the compilation of those instantiations across multiple cc files.
194   template <typename ReturnT, typename ElementwiseT>
195   friend class HloEvaluatorTypedVisitor;
196 
197   // Wraps around instruction handling to infer types before dispatching to
198   // the corresponding typed Visitor.
DefaultAction(HloInstruction * hlo)199   Status DefaultAction(HloInstruction* hlo) override {
200     return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
201   }
202 
203   Status Preprocess(HloInstruction* hlo) override;
204 
205   Status Postprocess(HloInstruction* hlo) override;
206 
207   // Operations that are type-agnostic or always return a specific type, such as
208   // HandleIsFinite where boolean is always returned.
209   //
210   Status HandleBitcast(HloInstruction* bitcast) override;
211 
212   Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override;
213 
214   Status HandleSetDimensionSize(HloInstruction* set_dimension_size) override;
215 
216   Status HandleParameter(HloInstruction* parameter) override;
217 
218   Status HandleConstant(HloInstruction* constant) override;
219 
220   Status HandleConcatenate(HloInstruction* concatenate) override;
221 
222   Status HandleReshape(HloInstruction* reshape) override;
223 
224   Status HandleTranspose(HloInstruction* transpose) override;
225 
226   Status HandleIsFinite(HloInstruction* is_finite) override;
227 
228   Status HandleCompare(HloInstruction* compare) override;
229 
230   Status HandleTuple(HloInstruction* tuple) override;
231 
232   Status HandleFft(HloInstruction* fft) override;
233 
234   Status HandleGather(HloInstruction* gather) override;
235 
236   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
237 
238   Status HandleCopy(HloInstruction* copy) override;
239 
240   Status HandleCopyStart(HloInstruction* copy_start) override;
241 
242   Status HandleCopyDone(HloInstruction* copy_done) override;
243 
244   Status HandleConditional(HloInstruction* conditional) override;
245 
246   Status HandleCall(HloInstruction* call) override;
247 
248   Status HandleFusion(HloInstruction* fusion) override;
249 
250   Status HandleWhile(HloInstruction* while_hlo) override;
251 
252   Status HandleSelect(HloInstruction* select) override;
253 
254   Status HandleTupleSelect(HloInstruction* tuple_select) override;
255 
256   Status HandleBroadcast(HloInstruction* broadcast) override;
257 
258   Status HandleAfterAll(HloInstruction* after_all) override;
259 
260   Status HandleAddDependency(HloInstruction* add_dependency) override;
261 
262   Status HandleSort(HloInstruction* sort) override;
263 
264   Status HandleReal(HloInstruction* real) override;
265 
266   Status HandleImag(HloInstruction* imag) override;
267 
268   Status HandleComplex(HloInstruction* complex) override;
269 
270   Status HandleReduce(HloInstruction* reduce) override;
271 
272   Status HandleReduceWindow(HloInstruction* hlo) override;
273 
274   Status HandleCustomCall(HloInstruction* custom_call) override;
275 
276   // Unsupported HLOs, note some of them (such as BatchNorm*) are typically
277   // expanded in a semantic-preserving way into other HLOs by adding expansion
278   // HLO pass to the HLO optimization pass during compilation, which can then be
279   // handled by the evaluator.
HandleBatchNormGrad(HloInstruction * batch_norm_grad)280   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
281     return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
282   };
HandleBatchNormInference(HloInstruction * batch_norm_inference)283   Status HandleBatchNormInference(
284       HloInstruction* batch_norm_inference) override {
285     return Unimplemented(
286         "BatchNormInference HLO is unsupported by the evaluator.");
287   };
HandleBatchNormTraining(HloInstruction * batch_norm_training)288   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
289     return Unimplemented(
290         "BatchNormTraining HLO is unsupported by the evaluator.");
291   };
HandleInfeed(HloInstruction * infeed)292   Status HandleInfeed(HloInstruction* infeed) override {
293     return Unimplemented("Infeed HLO is unsupported by the evaluator.");
294   };
HandleOutfeed(HloInstruction * outfeed)295   Status HandleOutfeed(HloInstruction* outfeed) override {
296     return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
297   };
298 
299   // Returns the already-evaluated literal result for the instruction.
300   //
301   // A Constant instruction is considered evaluated and its literal will be
302   // returned directly without looking up the cache.
303   //
304   // Similarly, a Parameter instruction is considered evaluated and its literal
305   // is looked up in arg_literals.
306   //
307   // Crash with log if the given instruction has not been evaluated previously.
GetEvaluatedLiteralFor(const HloInstruction * hlo)308   const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
309     if (hlo->IsConstant()) {
310       return hlo->literal();
311     }
312     if (hlo->opcode() == HloOpcode::kParameter) {
313       return *arg_literals_.at(hlo->parameter_number());
314     }
315     auto it = evaluated_.find(hlo);
316     CHECK(it != evaluated_.end())
317         << "could not find evaluated value for: " << hlo->ToString();
318     return it->second;
319   }
320 
321   // Tracks the HLO instruction and its evaluated literal result.
322   //
323   // Parameters and constants aren't stored here, see implementation of
324   // GetEvaluatedLiteralFor.
325   //
326   // TODO(b/35950897): have better memory management here to free instructions
327   // that are no longer a parent for any other subsequent instruction in
328   // post-ordering.
329   //
330   // Must be cleared for each evaluation.
331   //
332   // Storing Literal in place requires the container to have pointer stability
333   // so we cannot use flat_hash_map any more.
334   absl::node_hash_map<const HloInstruction*, Literal> evaluated_;
335 
336   // Use fast path that uses eigen in the evaluator.
337   bool use_fast_path_ = false;
338 
339  private:
340   template <typename ReturnT, typename NativeT>
ElementWiseUnaryOpImpl(HloInstruction * instruction,const std::function<ReturnT (NativeT)> & unary_op,const Literal & operand_literal)341   static StatusOr<Literal> ElementWiseUnaryOpImpl(
342       HloInstruction* instruction,
343       const std::function<ReturnT(NativeT)>& unary_op,
344       const Literal& operand_literal) {
345     const auto shape = instruction->shape();
346     const auto* operand = instruction->operand(0);
347     TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));
348 
349     Literal result(shape);
350     TF_RETURN_IF_ERROR(
351         result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
352           return unary_op(operand_literal.Get<NativeT>(multi_index));
353         }));
354     return std::move(result);
355   }
356 
357   // Map from a primitive type to its associated (templated) DfsHloVisitor.
358   std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];
359 
360   // Caches pointers to input literals, assuming they are in post-order.
361   // Literals are not owned by this class, and they must outlive the lifetime of
362   // each invocation to the Evaluate* method.
363   // Must be cleared for each evaluation.
364   std::vector<const Literal*> arg_literals_;
365 
366   // Max loop iterations to execute with no maximum if negative.
367   int64 max_loop_iterations_ = 0;
368 
369   // Module-level seed handle.
370   uint64 seed_ = 0;
371   // RNG engine.
372   std::minstd_rand0 engine_;
373 
374   // DynamicDimensionInference is used to evaluate GetDimensionSize, which
375   // returns the dynamic dimension size of its operand.
376   DynamicDimensionInference* dynamic_dimension_inference_ = nullptr;
377 
378   // Optional handler for custom_call ops.
379   std::function<StatusOr<Literal>(HloInstruction* custom_call,
380                                   absl::Span<const Literal*> operands)>
381       custom_call_handler_;
382 
383   TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
384 };
385 
386 std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs,
387                                               const Array2D<float>& rhs);
388 }  // namespace xla
389 
390 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
391