• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
18 
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
37 
38 namespace xla {
39 
40 // Responsible for evaluating HLO and obtain literal as the evaluation results.
41 //
42 // This class is not thread-safe.
43 class HloEvaluator : public DfsHloVisitorWithDefault {
44  public:
45   // Only evaluate up to max_loop_iterations per while-loop execution if
46   // specified.
47   explicit HloEvaluator(int64 max_loop_iterations = -1);
48 
49   // Evaluates an HLO module and an array of pointers to literals.  Returns the
50   // evaluated result as a literal if successful.
51   //
52   // Precondition: The indices of arg_literals correspond to the parameter
53   // numbers of the HLO parameters in the computation. See comment below for an
54   // example.
55   //
56   // (Dummy template arg is to reduce the overloading priority of one overload
57   // so that Evaluate(module, {}) resolves unambiguously.)
Evaluate(const HloModule & module,absl::Span<const Literal * const> arg_literals)58   StatusOr<Literal> Evaluate(const HloModule& module,
59                              absl::Span<const Literal* const> arg_literals) {
60     return Evaluate(*module.entry_computation(), arg_literals);
61   }
62   template <typename Dummy = void>
Evaluate(const HloModule & module,absl::Span<const Literal> arg_literals)63   StatusOr<Literal> Evaluate(const HloModule& module,
64                              absl::Span<const Literal> arg_literals) {
65     return Evaluate(*module.entry_computation(), arg_literals);
66   }
67 
68   // Evaluates an HLO computation and an array of pointers to literals.
69   // Returns the evaluated result as a literal if successful.
70   // Precondition: The indices of arg_literals correspond to the parameter
71   // numbers of the HLO parameters in the computation. For e.g., consider the
72   // following graph:
73   //
74   //                *
75   //            /       \
76   //            +     Parameter1
77   //        /      \
78   //       /        \
79   //    Parameter0  Constant
80   //
81   // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
82   // 1 in this computation. The input literals array will then have its first
83   // literal map to Parameter0 and the second map to Parameter1.
84   //
85   // (Dummy template arg is to reduce the overloading priority of one overload
86   // so that Evaluate(module, {}) resolves unambiguously.)
87   StatusOr<Literal> Evaluate(const HloComputation& computation,
88                              absl::Span<const Literal* const> arg_literals);
89   template <typename Dummy = void>
Evaluate(const HloComputation & computation,absl::Span<const Literal> arg_literals)90   StatusOr<Literal> Evaluate(const HloComputation& computation,
91                              absl::Span<const Literal> arg_literals) {
92     std::vector<const Literal*> arg_literal_ptrs;
93     for (const auto& l : arg_literals) {
94       arg_literal_ptrs.push_back(&l);
95     }
96     return Evaluate(computation, arg_literal_ptrs);
97   }
98 
99   // Gets the value of running a single HLO instruction.
100   //
101   // All of the operands to this instruction must be constants.
102   StatusOr<Literal> Evaluate(HloInstruction* instruction);
103 
104   // Same as Evaluate, except returning false on error and accepts an output
105   // pointer.
106   bool TryEvaluate(HloInstruction* instruction, Literal* result);
107 
108   // Evaluates a single HLO instruction, substituting the given literals for
109   // some of the instruction's operands.
110   //
111   // For example, given instruction = op(A, B, C) and the map
112   // {A = x, C = y}, this evaluates op(x, B, y).
113   StatusOr<Literal> EvaluateWithSubstitutions(
114       const HloInstruction* instruction,
115       const std::unordered_map<const HloInstruction*, const Literal*>&
116           substitutions);
117 
118   StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
119                                                 const Literal& lhs,
120                                                 const Literal& rhs);
121 
122   StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
123                                                const Literal& operand);
124 
125   StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
126                                   const PrecisionConfig& precision_config,
127                                   const Literal& lhs, const Literal& rhs);
128 
set_dynamic_dimension_inference(DynamicDimensionInference * dynamic_dimension_inference)129   void set_dynamic_dimension_inference(
130       DynamicDimensionInference* dynamic_dimension_inference) {
131     dynamic_dimension_inference_ = dynamic_dimension_inference;
132   }
133 
134   // Enable the fast path for certain operations like dot or convolution.
set_use_fast_path(bool value)135   void set_use_fast_path(bool value) { use_fast_path_ = value; }
136 
137   // Handles evaluation of a custom-call op.
138   // Operand literals are provided in |operands| and implementations must
139   // populate |output| before returning.
140   using CustomCallHandler = std::function<StatusOr<Literal>(
141       HloInstruction* custom_call, absl::Span<const Literal*> operands)>;
142 
143   // Sets a handler that is called during evaluation for custom-call ops.
144   // If no handler is defined the default error behavior will occur. The handler
145   // will be provided evaluated literals for all operands and is expected to
146   // return an output literal of the appropriate shape.
set_custom_call_handler(std::function<StatusOr<Literal> (HloInstruction * custom_call,absl::Span<const Literal * > operands)> handler)147   void set_custom_call_handler(
148       std::function<StatusOr<Literal>(HloInstruction* custom_call,
149                                       absl::Span<const Literal*> operands)>
150           handler) {
151     custom_call_handler_ = std::move(handler);
152   }
153 
154   // Returns the result of a matrix multiply `lhs x rhs`.
155   static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
156       const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
157   static std::unique_ptr<Array2D<float>> MatmulArray2D(
158       const Array2D<float>& lhs, const Array2D<float>& rhs);
159   static std::unique_ptr<Array2D<double>> MatmulArray2D(
160       const Array2D<double>& lhs, const Array2D<double>& rhs);
161 
162  protected:
163   // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
164   // class.
165   //
166   // A straightforward implementation would be to make it a nested class
167   // declared and defined in hlo_evaluator.cc.  Instead HloEvaluatorTypedVisitor
168   // lives as a separate class with its own header because its template gets
169   // instantiated many times and we want to use extern templates to shard out
170   // the compilation of those instantiations across multiple cc files.
171   template <typename ReturnT, typename ElementwiseT>
172   friend class HloEvaluatorTypedVisitor;
173 
174   // Wraps around instruction handling to infer types before dispatching to
175   // the corresponding typed Visitor.
DefaultAction(HloInstruction * hlo)176   Status DefaultAction(HloInstruction* hlo) override {
177     return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
178   }
179 
180   Status Preprocess(HloInstruction* hlo) override;
181 
182   Status Postprocess(HloInstruction* hlo) override;
183 
184   // Operations that are type-agnostic or always return a specific type, such as
185   // HandleIsFinite where boolean is always returned.
186   //
187   Status HandleBitcast(HloInstruction* bitcast) override;
188 
189   Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override;
190 
191   Status HandleParameter(HloInstruction* parameter) override;
192 
193   Status HandleConstant(HloInstruction* constant) override;
194 
195   Status HandleConcatenate(HloInstruction* concatenate) override;
196 
197   Status HandleReshape(HloInstruction* reshape) override;
198 
199   Status HandleTranspose(HloInstruction* transpose) override;
200 
201   Status HandleIsFinite(HloInstruction* is_finite) override;
202 
203   Status HandleCompare(HloInstruction* compare) override;
204 
205   Status HandleTuple(HloInstruction* tuple) override;
206 
207   Status HandleGather(HloInstruction* gather) override;
208 
209   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
210 
211   Status HandleCopy(HloInstruction* copy) override;
212 
213   Status HandleConditional(HloInstruction* conditional) override;
214 
215   Status HandleCall(HloInstruction* call) override;
216 
217   Status HandleFusion(HloInstruction* fusion) override;
218 
219   Status HandleWhile(HloInstruction* while_hlo) override;
220 
221   Status HandleSelect(HloInstruction* select) override;
222 
223   Status HandleTupleSelect(HloInstruction* tuple_select) override;
224 
225   Status HandleBroadcast(HloInstruction* broadcast) override;
226 
227   Status HandleAfterAll(HloInstruction* after_all) override;
228 
229   Status HandleAddDependency(HloInstruction* add_dependency) override;
230 
231   Status HandleSort(HloInstruction* sort) override;
232 
233   Status HandleReal(HloInstruction* real) override;
234 
235   Status HandleImag(HloInstruction* imag) override;
236 
237   Status HandleComplex(HloInstruction* complex) override;
238 
239   Status HandleReduce(HloInstruction* reduce) override;
240 
241   Status HandleCustomCall(HloInstruction* custom_call) override;
242 
243   // Unsupported HLOs, note some of them (such as BatchNorm*) are typically
244   // expanded in a semantic-preserving way into other HLOs by adding exanpsion
245   // HLO pass to the HLO optimization pass during compilation, which can then be
246   // handled by the evaluator.
HandleBatchNormGrad(HloInstruction * batch_norm_grad)247   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
248     return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
249   };
HandleBatchNormInference(HloInstruction * batch_norm_inference)250   Status HandleBatchNormInference(
251       HloInstruction* batch_norm_inference) override {
252     return Unimplemented(
253         "BatchNormInference HLO is unsupported by the evaluator.");
254   };
HandleBatchNormTraining(HloInstruction * batch_norm_training)255   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
256     return Unimplemented(
257         "BatchNormTraining HLO is unsupported by the evaluator.");
258   };
HandleInfeed(HloInstruction * infeed)259   Status HandleInfeed(HloInstruction* infeed) override {
260     return Unimplemented("Infeed HLO is unsupported by the evaluator.");
261   };
HandleOutfeed(HloInstruction * outfeed)262   Status HandleOutfeed(HloInstruction* outfeed) override {
263     return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
264   };
265 
266   // Returns the already-evaluated literal result for the instruction.
267   //
268   // A Constant instruction is considered evaluated and its literal will be
269   // returned directly without looking up the cache.
270   //
271   // Similarly, a Parameter instruction is considered evaluated and its literal
272   // is looked up in arg_literals.
273   //
274   // Crash with log if the given instruction has not been evaluated previously.
GetEvaluatedLiteralFor(const HloInstruction * hlo)275   const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
276     if (hlo->IsConstant()) {
277       return hlo->literal();
278     }
279     if (hlo->opcode() == HloOpcode::kParameter) {
280       return *arg_literals_.at(hlo->parameter_number());
281     }
282     auto it = evaluated_.find(hlo);
283     CHECK(it != evaluated_.end())
284         << "could not find evaluated value for: " << hlo->ToString();
285     return it->second;
286   }
287 
288   // Tracks the HLO instruction and its evaluated literal result.
289   //
290   // Parameters and constants aren't stored here, see implementation of
291   // GetEvaluatedLiteralFor.
292   //
293   // TODO(b/35950897): have better memory management here to free instructions
294   // that are no longer a parent for any other subsequent instruction in
295   // post-orderring.
296   //
297   // Must be cleared for each evaluation.
298   //
299   // Storing Literal in place requires the container to have pointer stability
300   // so we cannot use flat_hash_map any more.
301   absl::node_hash_map<const HloInstruction*, Literal> evaluated_;
302 
303   // Use fast path that uses eigen in the evaluator.
304   bool use_fast_path_ = false;
305 
306  private:
307   template <typename ReturnT, typename NativeT>
ElementWiseUnaryOpImpl(HloInstruction * instruction,const std::function<ReturnT (NativeT)> & unary_op,const Literal & operand_literal)308   static StatusOr<Literal> ElementWiseUnaryOpImpl(
309       HloInstruction* instruction,
310       const std::function<ReturnT(NativeT)>& unary_op,
311       const Literal& operand_literal) {
312     const auto shape = instruction->shape();
313     const auto* operand = instruction->operand(0);
314     TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));
315 
316     Literal result(shape);
317     TF_RETURN_IF_ERROR(
318         result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
319           return unary_op(operand_literal.Get<NativeT>(multi_index));
320         }));
321     return std::move(result);
322   }
323 
324   // Map from a primitive type to its associated (templated) DfsHloVisitor.
325   std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];
326 
327   // Caches pointers to input literals, assuming they are in post-order.
328   // Literals are not owned by this class, and they must outlive the lifetime of
329   // each invocation to the Evaluate* method.
330   // Must be cleared for each evaluation.
331   std::vector<const Literal*> arg_literals_;
332 
333   // Max loop iterations to execute with no maximum if negative.
334   int64 max_loop_iterations_ = 0;
335 
336   // Module-level seed handle.
337   uint64 seed_ = 0;
338   // RNG engine.
339   std::minstd_rand0 engine_;
340 
341   // DynamicDimensionInference is used to evaluate GetDimensionSize, which
342   // returns the dynamic dimension size of its operand.
343   DynamicDimensionInference* dynamic_dimension_inference_ = nullptr;
344 
345   // Optional handler for custom_call ops.
346   std::function<StatusOr<Literal>(HloInstruction* custom_call,
347                                   absl::Span<const Literal*> operands)>
348       custom_call_handler_;
349 
350   TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
351 };
352 
353 std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs,
354                                               const Array2D<float>& rhs);
355 }  // namespace xla
356 
357 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
358