• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstdlib>
20 #include <functional>
21 #include <string>
22 #include <type_traits>
23 #include <utility>
24 #include <vector>
25 
26 #include "tensorflow/compiler/xla/index_util.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/map_util.h"
30 #include "tensorflow/compiler/xla/primitive_util.h"
31 #include "tensorflow/compiler/xla/ptr_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/service/hlo_query.h"
35 #include "tensorflow/compiler/xla/service/shape_inference.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/util.h"
39 #include "tensorflow/compiler/xla/window_util.h"
40 #include "tensorflow/core/lib/core/bitmap.h"
41 #include "tensorflow/core/lib/core/casts.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/stringpiece.h"
45 #include "tensorflow/core/lib/gtl/optional.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/types.h"
49 
50 namespace xla {
51 
52 namespace {
53 
54 template <typename T>
55 struct is_complex_t : public std::false_type {};
56 
57 template <>
58 struct is_complex_t<complex64> : public std::true_type {};
59 
60 template <typename OperandT>
Compare(const Shape & shape,HloOpcode opcode,const Literal & lhs_literal,const Literal & rhs_literal)61 StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
62                                            const Literal& lhs_literal,
63                                            const Literal& rhs_literal) {
64   std::function<bool(OperandT, OperandT)> compare_op;
65   switch (opcode) {
66     case HloOpcode::kEq:
67       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
68         return lhs_el == rhs_el;
69       };
70       break;
71     case HloOpcode::kNe:
72       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
73         return lhs_el != rhs_el;
74       };
75       break;
76     case HloOpcode::kGe:
77       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
78         return lhs_el >= rhs_el;
79       };
80       break;
81     case HloOpcode::kGt:
82       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
83         return lhs_el > rhs_el;
84       };
85       break;
86     case HloOpcode::kLe:
87       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
88         return lhs_el <= rhs_el;
89       };
90       break;
91     case HloOpcode::kLt:
92       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
93         return lhs_el < rhs_el;
94       };
95       break;
96     default:
97       LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
98                  << HloOpcodeString(opcode);
99   }
100 
101   auto result = Literal::CreateFromShape(shape);
102   TF_RETURN_IF_ERROR(result->Populate<bool>(
103       [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
104         return compare_op(lhs_literal.Get<OperandT>(multi_index),
105                           rhs_literal.Get<OperandT>(multi_index));
106       }));
107 
108   return std::move(result);
109 }
110 
111 template <>
Compare(const Shape & shape,HloOpcode opcode,const Literal & lhs_literal,const Literal & rhs_literal)112 StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
113     const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
114     const Literal& rhs_literal) {
115   std::function<bool(complex64, complex64)> compare_op;
116   switch (opcode) {
117     case HloOpcode::kEq:
118       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
119         return lhs_el == rhs_el;
120       };
121       break;
122     case HloOpcode::kNe:
123       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
124         return lhs_el != rhs_el;
125       };
126       break;
127     default:
128       LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
129                  << HloOpcodeString(opcode);
130   }
131 
132   auto result = Literal::CreateFromShape(shape);
133   TF_RETURN_IF_ERROR(result->Populate<bool>(
134       [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
135         return compare_op(lhs_literal.Get<complex64>(multi_index),
136                           rhs_literal.Get<complex64>(multi_index));
137       }));
138 
139   return std::move(result);
140 }
141 
142 template <typename ReturnT, typename NativeT>
ElementWiseUnaryOpImpl(HloInstruction * instruction,const std::function<ReturnT (NativeT)> & unary_op,const Literal & operand_literal)143 StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
144     HloInstruction* instruction,
145     const std::function<ReturnT(NativeT)>& unary_op,
146     const Literal& operand_literal) {
147   const auto shape = instruction->shape();
148   const auto* operand = instruction->operand(0);
149 
150   // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
151   // removed.
152   if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
153     return Unimplemented(
154         "Implicit broadcasting is currently unsupported in HLO evaluator "
155         "Shape Mismatch: %s vs %s",
156         ShapeUtil::HumanString(shape).c_str(),
157         ShapeUtil::HumanString(operand->shape()).c_str());
158   }
159 
160   auto result = Literal::CreateFromShape(shape);
161 
162   TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
163       [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
164         return unary_op(operand_literal.Get<NativeT>(multi_index));
165       }));
166   return std::move(result);
167 }
168 
169 // For one particular placement of a window in a base shape (the placement is
170 // represented as `window_count_index`), iterates inside the window. Translates
171 // the window index into base index. If the base index is within bound, call `f`
172 // with the base index.
IterateThroughWindow(const Shape & window_shape,const Window & window,const Shape & base_shape,const tensorflow::gtl::ArraySlice<int64> & window_count_index,const std::function<void (const std::vector<int64> &)> & f)173 void IterateThroughWindow(
174     const Shape& window_shape, const Window& window, const Shape& base_shape,
175     const tensorflow::gtl::ArraySlice<int64>& window_count_index,
176     const std::function<void(const std::vector<int64>&)>& f) {
177   const int64 rank = ShapeUtil::Rank(base_shape);
178   DimensionVector window_index(rank);
179   std::fill(window_index.begin(), window_index.end(), 0);
180   do {
181     std::vector<int64> base_index(rank);
182     bool out_of_bound = false;
183     for (int64 i = 0; i < rank; ++i) {
184       base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
185                       window_index[i] - window.dimensions(i).padding_low();
186       if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
187         out_of_bound = true;
188         break;
189       }
190     }
191     if (!out_of_bound) {
192       f(base_index);
193     }
194   } while (IndexUtil::BumpIndices(window_shape, &window_index));
195 }
196 
197 }  // namespace
198 
199 template <typename ReturnT, typename ElementwiseT>
200 class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
201  public:
TypedVisitor(HloEvaluator * p)202   explicit TypedVisitor(HloEvaluator* p) : parent_(p) {}
203 
204   // The following higher-order functions convert a function with ElementwiseT
205   // to a function with ReturnT.
ConvertUnaryFunction(const std::function<ElementwiseT (ElementwiseT)> & unary_op)206   std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
207       const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
208     return [&unary_op](ReturnT arg) {
209       return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
210     };
211   }
ConvertBinaryFunction(const std::function<ElementwiseT (ElementwiseT,ElementwiseT)> & binary_op)212   std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
213       const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
214           binary_op) {
215     return [&binary_op](ReturnT arg1, ReturnT arg2) {
216       return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
217                                             static_cast<ElementwiseT>(arg2)));
218     };
219   }
ConvertTernaryFunction(const std::function<ElementwiseT (ElementwiseT,ElementwiseT,ElementwiseT)> & ternary_op)220   std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
221       const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
222                                        ElementwiseT)>& ternary_op) {
223     return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
224       return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
225                                              static_cast<ElementwiseT>(arg2),
226                                              static_cast<ElementwiseT>(arg3)));
227     };
228   }
229 
DefaultAction(HloInstruction * hlo_instruction)230   Status DefaultAction(HloInstruction* hlo_instruction) override {
231     return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
232                          HloOpcodeString(hlo_instruction->opcode()).c_str());
233   }
234 
  // TODO(b/35950897): many of the STL functions used in the handlers are not
  // overloaded for every XLA primitive type.
237 
238   template <typename NativeT,
239             typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
240                 nullptr>
HandleAbs(HloInstruction * abs)241   Status HandleAbs(HloInstruction* abs) {
242     TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
243                         ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
244                           return elem_operand;
245                         }));
246     return Status::OK();
247   }
248 
249   template <
250       typename NativeT,
251       typename std::enable_if<std::is_signed<NativeT>::value ||
252                               is_complex_t<NativeT>::value>::type* = nullptr>
HandleAbs(HloInstruction * abs)253   Status HandleAbs(HloInstruction* abs) {
254     TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
255                         ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) {
256                           return std::abs(elem_operand);
257                         }));
258     return Status::OK();
259   }
260 
HandleAbs(HloInstruction * abs)261   Status HandleAbs(HloInstruction* abs) override {
262     return HandleAbs<ElementwiseT>(abs);
263   }
264 
265   template <
266       typename NativeT,
267       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleRound(HloInstruction * round)268   Status HandleRound(HloInstruction* round) {
269     TF_ASSIGN_OR_RETURN(
270         parent_->evaluated_[round],
271         ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) {
272           return std::round(elem_operand);
273         }));
274     return Status::OK();
275   }
276 
277   template <
278       typename NativeT,
279       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleRound(HloInstruction * round)280   Status HandleRound(HloInstruction* round) {
281     return InvalidArgument("Unsupported type for Round");
282   }
283 
HandleRound(HloInstruction * round)284   Status HandleRound(HloInstruction* round) override {
285     return HandleRound<ReturnT>(round);
286   }
287 
HandleBroadcast(HloInstruction * broadcast)288   Status HandleBroadcast(HloInstruction* broadcast) override {
289     parent_->evaluated_[broadcast] =
290         Literal::CreateFromShape(broadcast->shape());
291     auto output = parent_->evaluated_[broadcast].get();
292     const Literal& operand_to_broadcast =
293         parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
294     std::vector<int64> broadcast_indices(
295         ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
296 
297     TF_RET_CHECK(broadcast->dimensions().size() ==
298                  ShapeUtil::Rank(operand_to_broadcast.shape()))
299         << "broadcast dimensions is of size: " << broadcast->dimensions().size()
300         << " and rank of operand_to_broadcast is: "
301         << ShapeUtil::Rank(operand_to_broadcast.shape());
302     // Checks that operand's dimensions are the same as the broadcast's
303     // dimensions along the dimensions to be broadcasted.
304     for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
305       TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
306                    operand_to_broadcast.shape().dimensions(i));
307     }
308 
309     return output->Populate<ReturnT>(
310         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
311           for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
312             broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
313           }
314           return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
315         });
316   }
317 
318   template <
319       typename NativeT,
320       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleCeil(HloInstruction * ceil)321   Status HandleCeil(HloInstruction* ceil) {
322     TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
323                         ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) {
324                           return std::ceil(elem_operand);
325                         }));
326     return Status::OK();
327   }
328 
329   template <
330       typename NativeT,
331       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleCeil(HloInstruction * ceil)332   Status HandleCeil(HloInstruction* ceil) {
333     return InvalidArgument("Unsupported type for Ceil");
334   }
335 
HandleCeil(HloInstruction * ceil)336   Status HandleCeil(HloInstruction* ceil) override {
337     return HandleCeil<ReturnT>(ceil);
338   }
339 
HandleConvert(HloInstruction * convert)340   Status HandleConvert(HloInstruction* convert) override {
341     const HloInstruction* operand = convert->operand(0);
342     TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
343     TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
344                         parent_->GetEvaluatedLiteralFor(operand).Convert(
345                             convert->shape().element_type()));
346 
347     if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
348       parent_->evaluated_[convert] = std::move(result);
349     } else {
350       parent_->evaluated_[convert] =
351           result->Relayout(convert->shape().layout());
352     }
353     return Status::OK();
354   }
355 
HandleExp(HloInstruction * exp)356   Status HandleExp(HloInstruction* exp) override {
357     TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
358                         ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) {
359                           return std::exp(elem_operand);
360                         }));
361     return Status::OK();
362   }
363 
364   template <
365       typename NativeT,
366       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleFloor(HloInstruction * floor)367   Status HandleFloor(HloInstruction* floor) {
368     TF_ASSIGN_OR_RETURN(
369         parent_->evaluated_[floor],
370         ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) {
371           return std::floor(elem_operand);
372         }));
373     return Status::OK();
374   }
375 
376   template <
377       typename NativeT,
378       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleFloor(HloInstruction * floor)379   Status HandleFloor(HloInstruction* floor) {
380     return InvalidArgument("Unsupported type for Floor");
381   }
382 
HandleFloor(HloInstruction * floor)383   Status HandleFloor(HloInstruction* floor) override {
384     return HandleFloor<ReturnT>(floor);
385   }
386 
HandleLog(HloInstruction * log)387   Status HandleLog(HloInstruction* log) override {
388     TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
389                         ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
390                           return std::log(elem_operand);
391                         }));
392     return Status::OK();
393   }
394 
395   template <typename NativeT,
396             typename std::enable_if<
397                 std::is_integral<NativeT>::value &&
398                 !std::is_same<NativeT, bool>::value>::type* = nullptr>
HandleNot(HloInstruction * not_)399   Status HandleNot(HloInstruction* not_) {
400     TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
401                         ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
402                           return ~elem_operand;
403                         }));
404     return Status::OK();
405   }
406 
407   template <typename NativeT, typename std::enable_if<std::is_floating_point<
408                                   NativeT>::value>::type* = nullptr>
HandleNot(HloInstruction * not_)409   Status HandleNot(HloInstruction* not_) {
410     TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
411                         ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
412                           return !elem_operand;
413                         }));
414     return Status::OK();
415   }
416 
417   template <typename NativeT,
418             typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
419                 nullptr>
HandleNot(HloInstruction * not_)420   Status HandleNot(HloInstruction* not_) {
421     TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
422                         ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
423                           return !elem_operand;
424                         }));
425     return Status::OK();
426   }
427 
428   template <
429       typename NativeT,
430       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleNot(HloInstruction * not_)431   Status HandleNot(HloInstruction* not_) {
432     return InvalidArgument("Unsupported type for Not");
433   }
434 
HandleNot(HloInstruction * not_)435   Status HandleNot(HloInstruction* not_) override {
436     return HandleNot<ElementwiseT>(not_);
437   }
438 
439   template <typename NativeT,
440             typename std::enable_if<
441                 std::is_signed<NativeT>::value &&
442                 !std::is_floating_point<NativeT>::value>::type* = nullptr>
HandleNegate(HloInstruction * negate)443   Status HandleNegate(HloInstruction* negate) {
444     using type = typename std::make_unsigned<NativeT>::type;
445     TF_ASSIGN_OR_RETURN(
446         parent_->evaluated_[negate],
447         ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) {
448           return NativeT(-type(elem_operand));
449         }));
450     return Status::OK();
451   }
452 
453   template <typename NativeT,
454             typename std::enable_if<
455                 !std::is_signed<NativeT>::value ||
456                 std::is_floating_point<NativeT>::value>::type* = nullptr>
HandleNegate(HloInstruction * negate)457   Status HandleNegate(HloInstruction* negate) {
458     TF_ASSIGN_OR_RETURN(
459         parent_->evaluated_[negate],
460         ElementWiseUnaryOp(
461             negate, [](ElementwiseT elem_operand) { return -elem_operand; }));
462     return Status::OK();
463   }
464 
HandleNegate(HloInstruction * negate)465   Status HandleNegate(HloInstruction* negate) override {
466     return HandleNegate<ReturnT>(negate);
467   }
468 
469   template <
470       typename NativeT,
471       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleSign(HloInstruction * sign)472   Status HandleSign(HloInstruction* sign) {
473     TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
474                         ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
475                           return (ElementwiseT(0) < elem_operand) -
476                                  (elem_operand < ElementwiseT(0));
477                         }));
478     return Status::OK();
479   }
480 
481   template <
482       typename NativeT,
483       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleSign(HloInstruction * sign)484   Status HandleSign(HloInstruction* sign) {
485     TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
486                         ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
487                           auto abs_val = std::abs(elem_operand);
488                           return 0 == abs_val ? ElementwiseT(0)
489                                               : elem_operand / abs_val;
490                         }));
491     return Status::OK();
492   }
493 
HandleSign(HloInstruction * sign)494   Status HandleSign(HloInstruction* sign) override {
495     return HandleSign<ReturnT>(sign);
496   }
497 
498   template <typename NativeT, typename std::enable_if<std::is_floating_point<
499                                   NativeT>::value>::type* = nullptr>
HandleAtan2(HloInstruction * atan2)500   Status HandleAtan2(HloInstruction* atan2) {
501     TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
502                         ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
503                                                       ElementwiseT rhs_elem) {
504                           return std::atan2(lhs_elem, rhs_elem);
505                         }));
506     return Status::OK();
507   }
508 
509   template <typename NativeT, typename std::enable_if<!std::is_floating_point<
510                                   NativeT>::value>::type* = nullptr>
HandleAtan2(HloInstruction * atan2)511   Status HandleAtan2(HloInstruction* atan2) {
512     return InvalidArgument("Unsupported type for Atan2");
513   }
514 
HandleAtan2(HloInstruction * atan2)515   Status HandleAtan2(HloInstruction* atan2) override {
516     return HandleAtan2<ElementwiseT>(atan2);
517   }
518 
HandleTanh(HloInstruction * tanh)519   Status HandleTanh(HloInstruction* tanh) override {
520     TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
521                         ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
522                           return std::tanh(elem_operand);
523                         }));
524     return Status::OK();
525   }
526 
527   template <typename NativeT,
528             typename std::enable_if<
529                 std::is_signed<NativeT>::value &&
530                 !std::is_floating_point<NativeT>::value>::type* = nullptr>
HandleMultiply(HloInstruction * multiply)531   Status HandleMultiply(HloInstruction* multiply) {
532     using type = typename std::make_unsigned<NativeT>::type;
533     TF_ASSIGN_OR_RETURN(
534         parent_->evaluated_[multiply],
535         ElementWiseBinaryOp(multiply,
536                             [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
537                               return NativeT(type(lhs_elem) * type(rhs_elem));
538                             }));
539     return Status::OK();
540   }
541 
542   template <
543       typename NativeT,
544       typename std::enable_if<std::is_unsigned<NativeT>::value ||
545                               std::is_floating_point<NativeT>::value ||
546                               is_complex_t<NativeT>::value>::type* = nullptr>
HandleMultiply(HloInstruction * multiply)547   Status HandleMultiply(HloInstruction* multiply) {
548     TF_ASSIGN_OR_RETURN(
549         parent_->evaluated_[multiply],
550         ElementWiseBinaryOp(multiply,
551                             [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
552                               return lhs_elem * rhs_elem;
553                             }));
554     return Status::OK();
555   }
556 
HandleMultiply(HloInstruction * multiply)557   Status HandleMultiply(HloInstruction* multiply) override {
558     return HandleMultiply<ElementwiseT>(multiply);
559   }
560 
HandleSubtract(HloInstruction * subtract)561   Status HandleSubtract(HloInstruction* subtract) override {
562     TF_ASSIGN_OR_RETURN(
563         parent_->evaluated_[subtract],
564         ElementWiseBinaryOp(subtract,
565                             [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
566                               return lhs_elem - rhs_elem;
567                             }));
568     return Status::OK();
569   }
570 
HandleAdd(HloInstruction * add)571   Status HandleAdd(HloInstruction* add) override {
572     TF_ASSIGN_OR_RETURN(parent_->evaluated_[add],
573                         ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem,
574                                                     ElementwiseT rhs_elem) {
575                           return lhs_elem + rhs_elem;
576                         }));
577     return Status::OK();
578   }
579 
HandleDivide(HloInstruction * divide)580   Status HandleDivide(HloInstruction* divide) override {
581     TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
582                         ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
583                                                        ElementwiseT rhs_elem) {
584                           return lhs_elem / rhs_elem;
585                         }));
586     return Status::OK();
587   }
588 
589   template <
590       typename NativeT,
591       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleMaximum(HloInstruction * maximum)592   Status HandleMaximum(HloInstruction* maximum) {
593     TF_ASSIGN_OR_RETURN(
594         parent_->evaluated_[maximum],
595         ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
596           return std::fmax(lhs, rhs);
597         }));
598     return Status::OK();
599   }
600 
601   template <
602       typename NativeT,
603       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleMaximum(HloInstruction * maximum)604   Status HandleMaximum(HloInstruction* maximum) {
605     return InvalidArgument("Unsupported type for Maximum");
606   }
607 
HandleMaximum(HloInstruction * maximum)608   Status HandleMaximum(HloInstruction* maximum) override {
609     return HandleMaximum<ElementwiseT>(maximum);
610   }
611 
612   template <
613       typename NativeT,
614       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleMinimum(HloInstruction * minimum)615   Status HandleMinimum(HloInstruction* minimum) {
616     TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
617                         ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
618                                                         ElementwiseT rhs_el) {
619                           return std::fmin(lhs_el, rhs_el);
620                         }));
621     return Status::OK();
622   }
623 
624   template <
625       typename NativeT,
626       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleMinimum(HloInstruction * minimum)627   Status HandleMinimum(HloInstruction* minimum) {
628     return InvalidArgument("Unsupported type for Minimum");
629   }
630 
HandleMinimum(HloInstruction * minimum)631   Status HandleMinimum(HloInstruction* minimum) override {
632     return HandleMinimum<ElementwiseT>(minimum);
633   }
634 
HandlePower(HloInstruction * power)635   Status HandlePower(HloInstruction* power) override {
636     TF_ASSIGN_OR_RETURN(parent_->evaluated_[power],
637                         ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
638                                                       ElementwiseT rhs_el) {
639                           return std::pow(lhs_el, rhs_el);
640                         }));
641     return Status::OK();
642   }
643 
644   template <
645       typename NativeT,
646       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleRemainder(HloInstruction * remainder)647   Status HandleRemainder(HloInstruction* remainder) {
648     TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
649                         ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
650                                                           ElementwiseT rhs_el) {
651                           return std::fmod(lhs_el, rhs_el);
652                         }));
653     return Status::OK();
654   }
655 
656   template <
657       typename NativeT,
658       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleRemainder(HloInstruction * remainder)659   Status HandleRemainder(HloInstruction* remainder) {
660     return InvalidArgument("Unsupported type for Remainder");
661   }
662 
HandleRemainder(HloInstruction * remainder)663   Status HandleRemainder(HloInstruction* remainder) override {
664     return HandleRemainder<ElementwiseT>(remainder);
665   }
666 
667   template <typename NativeT,
668             typename std::enable_if<std::is_integral<NativeT>::value>::type* =
669                 nullptr>
HandleAnd(HloInstruction * and_)670   Status HandleAnd(HloInstruction* and_) {
671     TF_ASSIGN_OR_RETURN(
672         parent_->evaluated_[and_],
673         ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
674           return lhs_el & rhs_el;
675         }));
676     return Status::OK();
677   }
678 
679   template <typename NativeT, typename std::enable_if<std::is_floating_point<
680                                   NativeT>::value>::type* = nullptr>
HandleAnd(HloInstruction * and_)681   Status HandleAnd(HloInstruction* and_) {
682     TF_ASSIGN_OR_RETURN(
683         parent_->evaluated_[and_],
684         ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
685           return lhs_el && rhs_el;
686         }));
687     return Status::OK();
688   }
689 
690   template <
691       typename NativeT,
692       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleAnd(HloInstruction * and_)693   Status HandleAnd(HloInstruction* and_) {
694     return InvalidArgument("Unsupported type for And");
695   }
696 
HandleAnd(HloInstruction * and_)697   Status HandleAnd(HloInstruction* and_) override {
698     return HandleAnd<ElementwiseT>(and_);
699   }
700 
701   template <typename NativeT,
702             typename std::enable_if<std::is_integral<NativeT>::value>::type* =
703                 nullptr>
HandleOr(HloInstruction * or_)704   Status HandleOr(HloInstruction* or_) {
705     TF_ASSIGN_OR_RETURN(
706         parent_->evaluated_[or_],
707         ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
708           return lhs_el | rhs_el;
709         }));
710     return Status::OK();
711   }
712 
713   template <typename NativeT, typename std::enable_if<std::is_floating_point<
714                                   NativeT>::value>::type* = nullptr>
HandleOr(HloInstruction * or_)715   Status HandleOr(HloInstruction* or_) {
716     TF_ASSIGN_OR_RETURN(
717         parent_->evaluated_[or_],
718         ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
719           return lhs_el || rhs_el;
720         }));
721     return Status::OK();
722   }
723 
724   template <
725       typename NativeT,
726       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleOr(HloInstruction * or_)727   Status HandleOr(HloInstruction* or_) {
728     return InvalidArgument("Unsupported type for Or");
729   }
730 
HandleOr(HloInstruction * or_)731   Status HandleOr(HloInstruction* or_) override {
732     return HandleOr<ElementwiseT>(or_);
733   }
734 
735   template <typename NativeT,
736             typename std::enable_if<
737                 std::is_integral<NativeT>::value &&
738                 !std::is_same<NativeT, bool>::value>::type* = nullptr>
HandleShiftLeft(HloInstruction * shl)739   Status HandleShiftLeft(HloInstruction* shl) {
740     TF_ASSIGN_OR_RETURN(
741         parent_->evaluated_[shl],
742         ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
743           return lhs_elem << rhs_elem;
744         }));
745     return Status::OK();
746   }
747 
748   template <typename NativeT,
749             typename std::enable_if<!std::is_integral<NativeT>::value ||
750                                     std::is_same<NativeT, bool>::value>::type* =
751                 nullptr>
HandleShiftLeft(HloInstruction *)752   Status HandleShiftLeft(HloInstruction*) {
753     return InvalidArgument("Unsupported type for ShiftLeft");
754   }
755 
HandleShiftLeft(HloInstruction * shl)756   Status HandleShiftLeft(HloInstruction* shl) override {
757     return HandleShiftLeft<ElementwiseT>(shl);
758   }
759   template <typename NativeT,
760             typename std::enable_if<
761                 std::is_integral<NativeT>::value &&
762                 !std::is_same<NativeT, bool>::value>::type* = nullptr>
HandleShiftRightArithmetic(HloInstruction * shr)763   Status HandleShiftRightArithmetic(HloInstruction* shr) {
764     typedef typename std::make_signed<NativeT>::type SignedT;
765     TF_ASSIGN_OR_RETURN(
766         parent_->evaluated_[shr],
767         ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
768           return static_cast<NativeT>(static_cast<SignedT>(lhs_elem) >>
769                                       rhs_elem);
770         }));
771     return Status::OK();
772   }
773 
774   template <typename NativeT,
775             typename std::enable_if<!std::is_integral<NativeT>::value ||
776                                     std::is_same<NativeT, bool>::value>::type* =
777                 nullptr>
HandleShiftRightArithmetic(HloInstruction *)778   Status HandleShiftRightArithmetic(HloInstruction*) {
779     return InvalidArgument("Unsupported type for ShiftRightArithmetic");
780   }
781 
HandleShiftRightArithmetic(HloInstruction * shra)782   Status HandleShiftRightArithmetic(HloInstruction* shra) override {
783     return HandleShiftRightArithmetic<ElementwiseT>(shra);
784   }
785 
786   template <typename NativeT,
787             typename std::enable_if<
788                 std::is_integral<NativeT>::value &&
789                 !std::is_same<NativeT, bool>::value>::type* = nullptr>
HandleShiftRightLogical(HloInstruction * shr)790   Status HandleShiftRightLogical(HloInstruction* shr) {
791     typedef typename std::make_unsigned<NativeT>::type UnsignedT;
792     TF_ASSIGN_OR_RETURN(
793         parent_->evaluated_[shr],
794         ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
795           // If shift amount is greater than the number of bits, then return 0.
796           if (rhs_elem >= sizeof(UnsignedT) * CHAR_BIT) {
797             return static_cast<NativeT>(0);
798           }
799           return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
800                                       rhs_elem);
801         }));
802     return Status::OK();
803   }
804 
805   template <typename NativeT,
806             typename std::enable_if<!std::is_integral<NativeT>::value ||
807                                     std::is_same<NativeT, bool>::value>::type* =
808                 nullptr>
HandleShiftRightLogical(HloInstruction *)809   Status HandleShiftRightLogical(HloInstruction*) {
810     return InvalidArgument("Unsupported type for ShiftRightLogical");
811   }
812 
HandleShiftRightLogical(HloInstruction * shrl)813   Status HandleShiftRightLogical(HloInstruction* shrl) override {
814     return HandleShiftRightLogical<ElementwiseT>(shrl);
815   }
816 
817   template <
818       typename NativeT,
819       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
HandleClamp(HloInstruction * clamp)820   Status HandleClamp(HloInstruction* clamp) {
821     std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
822         clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
823           return std::fmax(low, std::fmin(value, high));
824         };
825     TF_ASSIGN_OR_RETURN(
826         parent_->evaluated_[clamp],
827         ElementwiseTernaryOp(clamp,
828                              std::move(ConvertTernaryFunction(clamp_op))));
829     return Status::OK();
830   }
831 
832   template <
833       typename NativeT,
834       typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
HandleClamp(HloInstruction *)835   Status HandleClamp(HloInstruction*) {
836     return InvalidArgument("Unsupported type for Clamp");
837   }
838 
HandleClamp(HloInstruction * clamp)839   Status HandleClamp(HloInstruction* clamp) override {
840     return HandleClamp<ElementwiseT>(clamp);
841   }
842 
HandleSelect(HloInstruction * select)843   Status HandleSelect(HloInstruction* select) override {
844     CHECK(!ShapeUtil::IsTuple(select->shape()));
845     std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
846         [](bool pred, ReturnT on_true, ReturnT on_false) {
847           if (pred) {
848             return on_true;
849           }
850           return on_false;
851         };
852     TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
853                         ElementwiseTernaryOp(select, std::move(select_op)));
854     return Status::OK();
855   }
856 
HandleReverse(HloInstruction * reverse)857   Status HandleReverse(HloInstruction* reverse) override {
858     const auto result_shape = reverse->shape();
859     const auto reverse_dimensions = reverse->dimensions();
860 
861     auto operand = reverse->operand(0);
862     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
863                         ShapeInference::InferReverseShape(operand->shape(),
864                                                           reverse_dimensions));
865 
866     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
867         << "return shape set to: " << ShapeUtil::HumanString(result_shape)
868         << " but is inferred to be: "
869         << ShapeUtil::HumanString(inferred_return_shape);
870 
871     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
872     auto result = Literal::CreateFromShape(result_shape);
873 
874     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
875         [&](tensorflow::gtl::ArraySlice<int64> out_index) {
876           std::vector<int64> from_index(out_index.begin(), out_index.end());
877           for (const int64 dim : reverse_dimensions) {
878             from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
879           }
880           return operand_literal.Get<ReturnT>(from_index);
881         }));
882 
883     parent_->evaluated_[reverse] = std::move(result);
884     return Status::OK();
885   }
886 
HandleConvolution(HloInstruction * conv)887   Status HandleConvolution(HloInstruction* conv) override {
888     auto lhs = conv->operand(0);
889     auto rhs = conv->operand(1);
890     const auto& window = conv->window();
891     const Shape& result_shape = conv->shape();
892     const Shape& lhs_shape = lhs->shape();
893     const Shape& rhs_shape = rhs->shape();
894 
895     TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
896     TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
897     CHECK(ShapeUtil::IsArray(lhs_shape));
898     CHECK(ShapeUtil::IsArray(rhs_shape));
899     CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
900     CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
901 
902     const auto& dnums = conv->convolution_dimension_numbers();
903     const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
904     CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size());
905     CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
906     CHECK_GE(num_spatial_dims, 0);
907     CHECK_EQ(window.dimensions_size(), num_spatial_dims);
908 
909     const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
910     const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
911 
912     CHECK_EQ(num_spatial_dims + 2, lhs_rank);
913     CHECK_EQ(num_spatial_dims + 2, rhs_rank);
914 
915     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
916                         ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
917                                                            window, dnums));
918     CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
919         << "return shape set to: " << ShapeUtil::HumanString(result_shape)
920         << " but is inferred to be: "
921         << ShapeUtil::HumanString(inferred_return_shape);
922 
923     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
924     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
925 
926     // Dimension number applicable for input (lhs).
927     const int64 input_batch_dim = dnums.input_batch_dimension();
928     const int64 input_z_dim = dnums.input_feature_dimension();
929     // Dimension number applicable for kernel (rhs).
930     const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
931     const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
932     // Dimension number applicable for output.
933     const int64 output_batch_dim = dnums.output_batch_dimension();
934     const int64 output_z_dim = dnums.output_feature_dimension();
935 
936     const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
937 
938     std::vector<int64> window_dimension_sizes;
939     for (auto i : dnums.kernel_spatial_dimensions()) {
940       window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
941     }
942 
943     const Shape& window_shape =
944         ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
945 
946     DimensionVector lhs_index(lhs_rank);
947     DimensionVector rhs_index(rhs_rank);
948     DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
949 
950     auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
951       ElementwiseT result_val = static_cast<ElementwiseT>(0);
952 
953       std::fill(lhs_index.begin(), lhs_index.end(), 0);
954       std::fill(rhs_index.begin(), rhs_index.end(), 0);
955       std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
956 
957       lhs_index[input_batch_dim] = out_index[output_batch_dim];
958       rhs_index[kernel_output_z_dim] = out_index[output_z_dim];
959 
960       // Convolve input feature with kernel.
961       do {
962         for (int64 iz = 0; iz < z_size; ++iz) {
963           lhs_index[input_z_dim] = iz;
964           rhs_index[kernel_input_z_dim] = iz;
965 
966           // Find corresponding spatial dimension index for input (lhs).
967           for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
968             // Spatial dimension number for input (lhs) and output.
969             const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
970             const int64 output_spatial_dim =
971                 dnums.output_spatial_dimensions(ki);
972 
973             // Calculate lhs (input) index without taking base dilation into
974             // account.
975             const auto& window_dim = window.dimensions(ki);
976             const int64 undilated_index =
977                 out_index[output_spatial_dim] * window_dim.stride() -
978                 window_dim.padding_low() +
979                 rhs_spatial_index[ki] * window_dim.window_dilation();
980             // Skip if the lhs (input) index is to be dilated.  As an
981             // optimization, skip this mod if there's no dilation.
982             if (window_dim.base_dilation() > 1 &&
983                 undilated_index % window_dim.base_dilation() != 0) {
984               goto cnt;
985             }
986 
987             // Calculate the actual lhs (input) index after dilation.  As an
988             // optimization, skip this integer divide if there's no dilation.
989             if (window_dim.base_dilation() > 1) {
990               lhs_index[input_spatial_dim] =
991                   undilated_index / window_dim.base_dilation();
992             } else {
993               lhs_index[input_spatial_dim] = undilated_index;
994             }
995 
996             // Skip if input index is not in bound.
997             if (!(lhs_index[input_spatial_dim] >= 0 &&
998                   lhs_index[input_spatial_dim] <
999                       lhs_shape.dimensions(input_spatial_dim))) {
1000               goto cnt;
1001             }
1002 
1003             rhs_index[dnums.kernel_spatial_dimensions(ki)] =
1004                 window_dim.window_reversal()
1005                     ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
1006                     : rhs_spatial_index[ki];
1007           }
1008 
1009           result_val +=
1010               static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
1011               static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
1012         }
1013       cnt : {}
1014       } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
1015 
1016       return static_cast<ReturnT>(result_val);
1017     };
1018 
1019     auto result = Literal::CreateFromShape(result_shape);
1020     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
1021 
1022     parent_->evaluated_[conv] = std::move(result);
1023     return Status::OK();
1024   }
1025 
HandleDot(HloInstruction * dot)1026   Status HandleDot(HloInstruction* dot) override {
1027     auto lhs = dot->operand(0);
1028     auto rhs = dot->operand(1);
1029     CHECK(ShapeUtil::IsArray(dot->shape()));
1030     CHECK(ShapeUtil::IsArray(lhs->shape()));
1031     CHECK(ShapeUtil::IsArray(rhs->shape()));
1032 
1033     const auto& dnums = dot->dot_dimension_numbers();
1034 
1035     const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
1036     const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
1037 
1038     CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
1039     CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
1040 
1041     // There must be 1 and only 1 Contracting dimension for lhs and rhs.
1042     CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1);
1043     CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1);
1044     const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
1045     const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
1046     // Contracted dimension sizes must be the same.
1047     CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
1048              rhs->shape().dimensions(rhs_contracting_dimension))
1049         << "lhs contracted dimension: "
1050         << lhs->shape().dimensions(lhs_contracting_dimension)
1051         << " rhs contracted dimension: "
1052         << rhs->shape().dimensions(rhs_contracting_dimension);
1053     const int64 contracted_dimension_size =
1054         lhs->shape().dimensions(lhs_contracting_dimension);
1055 
1056     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
1057     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
1058 
1059     auto result = Literal::CreateFromShape(dot->shape());
1060 
1061     CHECK_EQ(dnums.lhs_batch_dimensions_size(),
1062              dnums.rhs_batch_dimensions_size());
1063 
1064     std::vector<int64> lhs_non_contracting_dims;
1065     for (int64 i = 0; i < lhs_rank; i++) {
1066       if (i != lhs_contracting_dimension) {
1067         lhs_non_contracting_dims.push_back(i);
1068       }
1069     }
1070 
1071     std::vector<int64> rhs_non_batch_non_contracting_dims;
1072     tensorflow::gtl::FlatSet<int64> batch_dims_set(
1073         dnums.rhs_batch_dimensions().begin(),
1074         dnums.rhs_batch_dimensions().end());
1075     for (int64 i = 0; i < rhs_rank; i++) {
1076       if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
1077         rhs_non_batch_non_contracting_dims.push_back(i);
1078       }
1079     }
1080 
1081     const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
1082     const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
1083 
1084     DimensionVector lhs_index(lhs_rank);
1085     DimensionVector rhs_index(rhs_rank);
1086     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1087         [&](tensorflow::gtl::ArraySlice<int64> result_index) {
1088           ElementwiseT result_val = static_cast<ElementwiseT>(0);
1089 
1090           // Find the corresponding non-contracting indices for lhs and rhs.
1091           //
1092           // For `result_index`, its batch dimension, if exists, will be at the
1093           // same dimension as the batch dimension of lhs and rhs. More
1094           // specifically:
1095           // - For lhs, the non-contracting dimensions, including the batch
1096           // dimension have the same index as the `result_index`.
1097           // - For rhs, the batch dimension is set seperately from other
1098           // non-contracting dimensions, since these other non-contracting
1099           // dimensions in rhs follow the non-contracting dimensions of lhs in
1100           // the resulting index.
1101           //
1102           // As an example, for a resulting index:
1103           //  result_index [result_batch, result_x, result_y]
1104           // the effecting lhs and rhs indices are:
1105           //  lhs [result_batch, lhs_non_contracting_dim, contracting_dim
1106           //  rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
1107           // `result_x` is only affected by the lhs_non_contracting_dim and
1108           // likewise `result_y` only depends on rhs_non_contracting_dim.
1109           //
1110           // so we can look up the lhs and rhs indices by:
1111           //
1112           // lhs:
1113           //  batch index is the same as `result_batch`.
1114           //    non-contracting dimension is the same as
1115           //    result_index[lhs_non_contracting_dim]
1116           // rhs:
1117           //  batch index: the same as `result_batch`.
1118           //  non-contracting dimension index: *not* the same as
1119           //    result_index[rhs_non_contractng_dim], since the
1120           //    non-contracting dimensions of lhs are included in the
1121           //    result_index first. Instead, the non_contracting_dim of rhs must
1122           //    be calculated as following:
1123           //      lhs_non_contracting_dimensions_size +
1124           //      (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
1125           //
1126           //    Note that (rhs_non_batch_contracting_dim - batch_dim_size) is
1127           //    the index offset to the result_index that only depends on
1128           //    the non_batch and non-contracting dimensions of rhs. -1 at the
1129           //    end translates size to index.
1130           for (auto i : lhs_non_contracting_dims) {
1131             lhs_index[i] = result_index[i];
1132           }
1133           for (auto i : dnums.rhs_batch_dimensions()) {
1134             rhs_index[i] = result_index[i];
1135           }
1136           for (auto i : rhs_non_batch_non_contracting_dims) {
1137             const int64 rhs_non_batch_non_contracting_dim =
1138                 lhs_non_contracting_size + (i - batch_dim_size) - 1;
1139             rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
1140           }
1141 
1142           // Accumulates resulting product along the contracted dimension.
1143           for (int64 i = 0; i < contracted_dimension_size; ++i) {
1144             lhs_index[lhs_contracting_dimension] = i;
1145             rhs_index[rhs_contracting_dimension] = i;
1146 
1147             result_val +=
1148                 static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
1149                 static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
1150           }
1151 
1152           return static_cast<ReturnT>(result_val);
1153         }));
1154 
1155     parent_->evaluated_[dot] = std::move(result);
1156     return Status::OK();
1157   }
1158 
HandlePad(HloInstruction * pad)1159   Status HandlePad(HloInstruction* pad) override {
1160     CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
1161     // Padding value must be scalar.
1162     CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
1163     CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
1164              pad->padding_config().dimensions_size());
1165 
1166     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
1167                         ShapeInference::InferPadShape(
1168                             /*operand_shape=*/pad->operand(0)->shape(),
1169                             /*padding_value_shape=*/pad->operand(1)->shape(),
1170                             /*padding_config=*/pad->padding_config()));
1171     CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
1172         << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
1173         << "but is inferred to be: "
1174         << ShapeUtil::HumanString(inferred_return_shape);
1175 
1176     // Create new HLO of padded shape with padding value.
1177     ReturnT scalar =
1178         parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
1179     auto result = Literal::CreateFromShape(pad->shape());
1180     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1181         [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
1182           return scalar;
1183         }));
1184 
1185     const Literal& evaluated_operand =
1186         parent_->GetEvaluatedLiteralFor(pad->operand(0));
1187 
1188     std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
1189                                    0);
1190     std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
1191 
1192     // Loop through each element of the operand, assign them to the
1193     // corresponding index of the resulting padded literal.
1194     const PaddingConfig& pad_config = pad->padding_config();
1195 
1196     auto func = [&](const std::vector<int64>& input_index) {
1197       for (auto i = 0; i < input_index.size(); ++i) {
1198         // Interior padding occurs logically before edge padding, so in the case
1199         // of negative edge padding elements are removed from the
1200         // interior-padded operand.
1201         target_index[i] =
1202             pad_config.dimensions(i).edge_padding_low() +
1203             input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);
1204 
1205         // Account for negative low and high padding: skip assignment if the
1206         // any target index is out of range.
1207         if (!(target_index[i] >= 0 &&
1208               target_index[i] < pad->shape().dimensions(i))) {
1209           return true;
1210         }
1211       }
1212       result->Set<ReturnT>(target_index,
1213                            evaluated_operand.Get<ReturnT>(input_index));
1214       return true;
1215     };
1216 
1217     std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(),
1218                                  0);
1219     std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1);
1220 
1221     ShapeUtil::ForEachIndex(
1222         evaluated_operand.shape(), zero_base,
1223         AsInt64Slice(evaluated_operand.shape().dimensions()), step, func);
1224 
1225     parent_->evaluated_[pad] = std::move(result);
1226     return Status::OK();
1227   }
1228 
HandleDynamicSlice(HloInstruction * dynamic_slice)1229   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
1230     auto operand = dynamic_slice->operand(0);
1231     auto start_indices = dynamic_slice->operand(1);
1232     auto result_shape = dynamic_slice->shape();
1233     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
1234                         ShapeInference::InferDynamicSliceShape(
1235                             operand->shape(), start_indices->shape(),
1236                             dynamic_slice->dynamic_slice_sizes()));
1237     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
1238         << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
1239         << "but is inferred to be: "
1240         << ShapeUtil::HumanString(inferred_return_shape);
1241     TF_RET_CHECK(
1242         primitive_util::IsIntegralType(start_indices->shape().element_type()));
1243 
1244     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
1245     const Literal& start_indices_literal =
1246         parent_->GetEvaluatedLiteralFor(start_indices);
1247 
1248     switch (start_indices->shape().element_type()) {
1249       case S32: {
1250         TF_ASSIGN_OR_RETURN(
1251             parent_->evaluated_[dynamic_slice],
1252             DynamicSlice<int32>(operand_literal, start_indices_literal,
1253                                 result_shape));
1254       } break;
1255       case S64: {
1256         TF_ASSIGN_OR_RETURN(
1257             parent_->evaluated_[dynamic_slice],
1258             DynamicSlice<int64>(operand_literal, start_indices_literal,
1259                                 result_shape));
1260       } break;
1261       case U32: {
1262         TF_ASSIGN_OR_RETURN(
1263             parent_->evaluated_[dynamic_slice],
1264             DynamicSlice<uint32>(operand_literal, start_indices_literal,
1265                                  result_shape));
1266       } break;
1267       case U64: {
1268         TF_ASSIGN_OR_RETURN(
1269             parent_->evaluated_[dynamic_slice],
1270             DynamicSlice<uint64>(operand_literal, start_indices_literal,
1271                                  result_shape));
1272       } break;
1273       default:
1274         LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
1275                       "start_indices: "
1276                    << PrimitiveType_Name(start_indices->shape().element_type());
1277     }
1278 
1279     return Status::OK();
1280   }
1281 
HandleDynamicUpdateSlice(HloInstruction * dynamic_update_slice)1282   Status HandleDynamicUpdateSlice(
1283       HloInstruction* dynamic_update_slice) override {
1284     auto operand = dynamic_update_slice->operand(0);
1285     auto update = dynamic_update_slice->operand(1);
1286     auto start_indices = dynamic_update_slice->operand(2);
1287     auto result_shape = dynamic_update_slice->shape();
1288     TF_ASSIGN_OR_RETURN(
1289         auto inferred_return_shape,
1290         ShapeInference::InferDynamicUpdateSliceShape(
1291             operand->shape(), update->shape(), start_indices->shape()));
1292     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
1293         << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
1294         << "but is inferred to be: "
1295         << ShapeUtil::HumanString(inferred_return_shape);
1296     TF_RET_CHECK(
1297         primitive_util::IsIntegralType(start_indices->shape().element_type()));
1298     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
1299 
1300     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
1301     const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
1302     const Literal& start_indices_literal =
1303         parent_->GetEvaluatedLiteralFor(start_indices);
1304 
1305     switch (start_indices->shape().element_type()) {
1306       case S32: {
1307         TF_ASSIGN_OR_RETURN(
1308             parent_->evaluated_[dynamic_update_slice],
1309             DynamicUpdateSlice<int32>(operand_literal, update_literal,
1310                                       start_indices_literal));
1311       } break;
1312       case S64: {
1313         TF_ASSIGN_OR_RETURN(
1314             parent_->evaluated_[dynamic_update_slice],
1315             DynamicUpdateSlice<int64>(operand_literal, update_literal,
1316                                       start_indices_literal));
1317       } break;
1318       case U32: {
1319         TF_ASSIGN_OR_RETURN(
1320             parent_->evaluated_[dynamic_update_slice],
1321             DynamicUpdateSlice<uint32>(operand_literal, update_literal,
1322                                        start_indices_literal));
1323       } break;
1324       case U64: {
1325         TF_ASSIGN_OR_RETURN(
1326             parent_->evaluated_[dynamic_update_slice],
1327             DynamicUpdateSlice<uint64>(operand_literal, update_literal,
1328                                        start_indices_literal));
1329       } break;
1330       default:
1331         LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
1332                       "start_indices: "
1333                    << PrimitiveType_Name(start_indices->shape().element_type());
1334     }
1335 
1336     return Status::OK();
1337   }
1338 
1339   template <typename NativeT>
MapImpl(HloInstruction * map)1340   StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
1341     auto operands = map->operands();
1342     HloComputation* computation = map->to_apply();
1343 
1344     auto result = Literal::CreateFromShape(map->shape());
1345 
1346     HloEvaluator embedded_evaluator;
1347     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1348         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
1349           std::vector<std::unique_ptr<Literal>> arg_literals;
1350           arg_literals.reserve(operands.size());
1351 
1352           // Construct scalar literal parameters to be passed to the map
1353           // computation.
1354           for (auto operand : operands) {
1355             const Literal& arg_literal =
1356                 parent_->GetEvaluatedLiteralFor(operand);
1357 
1358             auto curr_val = arg_literal.Get<NativeT>(multi_index);
1359             auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
1360 
1361             arg_literals.push_back(std::move(curr_val_literal));
1362           }
1363 
1364           std::unique_ptr<Literal> computed_result =
1365               embedded_evaluator
1366                   .Evaluate<std::unique_ptr<Literal>>(*computation,
1367                                                       arg_literals)
1368                   .ConsumeValueOrDie();
1369           // Clear visit states so that the we can use the evaluate again on
1370           // the same computation.
1371           embedded_evaluator.ResetVisitStates();
1372 
1373           return computed_result->Get<ReturnT>({});
1374         }));
1375     return std::move(result);
1376   }
1377 
HandleMap(HloInstruction * map)1378   Status HandleMap(HloInstruction* map) override {
1379     switch (map->operand(0)->shape().element_type()) {
1380       case PRED: {
1381         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map));
1382         break;
1383       }
1384       case U8: {
1385         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
1386         break;
1387       }
1388       case U32: {
1389         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
1390         break;
1391       }
1392       case U64: {
1393         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map));
1394         break;
1395       }
1396       case S8: {
1397         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
1398         break;
1399       }
1400       case S32: {
1401         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
1402         break;
1403       }
1404       case S64: {
1405         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
1406         break;
1407       }
1408       case F16: {
1409         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
1410                             MapImpl<Eigen::half>(map));
1411         break;
1412       }
1413       case F32: {
1414         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
1415         break;
1416       }
1417       case F64: {
1418         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map));
1419         break;
1420       }
1421       case C64: {
1422         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
1423         break;
1424       }
1425       default:
1426         LOG(FATAL) << "HandleMap: unhandled primitive type for "
1427                       "input operand: "
1428                    << PrimitiveType_Name(
1429                           map->operand(0)->shape().element_type());
1430     }
1431 
1432     return Status::OK();
1433   }
1434 
HandleReduce(HloInstruction * reduce)1435   Status HandleReduce(HloInstruction* reduce) override {
1436     auto arg = reduce->operand(0);
1437     auto init_value = reduce->operand(1);
1438     tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
1439     HloComputation* function = reduce->to_apply();
1440     TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
1441                  ShapeUtil::Rank(arg->shape()) - dimensions.size());
1442     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
1443                         ShapeInference::InferReduceShape(
1444                             /*arg=*/arg->shape(),
1445                             /*init_value=*/init_value->shape(),
1446                             /*dimensions_to_reduce=*/dimensions,
1447                             /*to_apply=*/function->ComputeProgramShape()));
1448     TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
1449         << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
1450         << "but is inferred to be: "
1451         << ShapeUtil::HumanString(inferred_return_shape);
1452 
1453     const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
1454     VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
1455     const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
1456     VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
1457     TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
1458     auto init_scalar = init_literal.Get<ReturnT>({});
1459 
1460     auto result = Literal::CreateFromShape(reduce->shape());
1461 
1462     const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
1463     std::vector<int64> arg_dim_steps(arg_dimensions.size());
1464     std::vector<int64> arg_dim_counts(arg_dimensions.size());
1465     for (const int64 dim : dimensions) {
1466       arg_dim_steps[dim] = 1;
1467       arg_dim_counts[dim] = arg_dimensions[dim];
1468     }
1469 
1470     // Create mapping from result index to arg index.
1471     const int64 result_rank = ShapeUtil::Rank(result->shape());
1472     int64 result_dim = 0;
1473     std::vector<int64> result_to_arg_index(result_rank);
1474     for (int64 i = 0; i < arg_dimensions.size(); ++i) {
1475       if (arg_dim_steps[i] == 0) {
1476         result_to_arg_index[result_dim] = i;
1477         ++result_dim;
1478       }
1479     }
1480 
1481     HloEvaluator embedded_evaluator;
1482     // For each resulting dimension, calculate and assign computed value.
1483     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1484         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
1485           ReturnT result_val = init_scalar;
1486 
1487           std::vector<int64> base(arg_dimensions.size());
1488           for (int64 i = 0; i < multi_index.size(); ++i) {
1489             base[result_to_arg_index[i]] = multi_index[i];
1490           }
1491 
1492           auto func = [&](const std::vector<int64>& input_index) {
1493             auto curr_val = arg_literal.Get<ReturnT>(input_index);
1494 
1495             // Evaluate computation with specified literal operands.
1496             auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
1497             auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
1498             std::vector<const Literal*> args = {curr_val_literal.get(),
1499                                                 result_val_literal.get()};
1500 
1501             std::unique_ptr<Literal> computed_result =
1502                 embedded_evaluator.Evaluate<const Literal*>(*function, args)
1503                     .ConsumeValueOrDie();
1504             // Clear visit states so that the we can use the evaluate again on
1505             // the same computation.
1506             embedded_evaluator.ResetVisitStates();
1507 
1508             // Assign computed result to result_val.
1509             result_val = computed_result->Get<ReturnT>({});
1510 
1511             return true;
1512           };
1513 
1514           ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
1515                                   arg_dim_steps, func);
1516 
1517           return result_val;
1518         }));
1519 
1520     parent_->evaluated_[reduce] = std::move(result);
1521     return Status::OK();
1522   }
1523 
HandleSelectAndScatter(HloInstruction * select_and_scatter)1524   Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
1525     auto operand = select_and_scatter->operand(0);
1526     auto source = select_and_scatter->operand(1);
1527     const Window& window = select_and_scatter->window();
1528 
1529     const Literal& init_literal =
1530         parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
1531     TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
1532     auto init_scalar = init_literal.Get<ReturnT>({});
1533 
1534     auto result = Literal::CreateFromShape(select_and_scatter->shape());
1535 
1536     // Initialize result array with the init value.
1537     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1538         [&](tensorflow::gtl::ArraySlice<int64> output_index) {
1539           return init_scalar;
1540         }));
1541 
1542     std::vector<int64> window_dimension_sizes;
1543     for (const auto& window_dimension : window.dimensions()) {
1544       window_dimension_sizes.push_back(window_dimension.size());
1545     }
1546     const Shape window_shape = ShapeUtil::MakeShape(
1547         operand->shape().element_type(), window_dimension_sizes);
1548 
1549     HloComputation* select = select_and_scatter->select();
1550     HloComputation* scatter = select_and_scatter->scatter();
1551 
1552     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
1553     const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
1554 
1555     int64 rank = ShapeUtil::Rank(operand_literal.shape());
1556 
1557     HloEvaluator embedded_evaluator;
1558     DimensionVector source_index(rank);
1559 
1560     std::fill(source_index.begin(), source_index.end(), 0);
1561     do {
1562       // For each element in `source`, we place a window in `operand`. For each
1563       // window placement, we iterate inside the window twice:
1564       //
1565       // 1. Find the selected index by applying `select` function to all
1566       // elements. E.g., If the `select` function is GreaterEqual, the first
1567       // iteration through the window finds the biggest value and returns its
1568       // index.
1569       //
1570       // 2. Using the selected index, scatter value from `source` to result. We
1571       // do this by iterating through the window, and compare each index with
1572       // the selected index.
1573       tensorflow::gtl::optional<ReturnT> selected_val;
1574       tensorflow::gtl::optional<std::vector<int64>> selected_index;
1575 
1576       IterateThroughWindow(
1577           window_shape, window, operand_literal.shape(), source_index,
1578           [&](const std::vector<int64>& operand_index) {
1579             auto curr_val = operand_literal.Get<ReturnT>(operand_index);
1580             if (!selected_val) {
1581               selected_val = curr_val;
1582               selected_index = operand_index;
1583             }
1584             const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
1585             const auto selected_val_literal =
1586                 Literal::CreateR0<ReturnT>(*selected_val);
1587 
1588             const std::vector<const Literal*> args = {
1589                 curr_val_literal.get(), selected_val_literal.get()};
1590             std::unique_ptr<Literal> computed_result =
1591                 embedded_evaluator.Evaluate<const Literal*>(*select, args)
1592                     .ConsumeValueOrDie();
1593             bool selected = computed_result->Get<bool>({});
1594             if (selected) {
1595               selected_val = curr_val;
1596               selected_index = operand_index;
1597             }
1598             embedded_evaluator.ResetVisitStates();
1599           });
1600 
1601       IterateThroughWindow(
1602           window_shape, window, operand_literal.shape(), source_index,
1603           [&](const std::vector<int64>& operand_index) {
1604             if (std::equal(operand_index.begin(), operand_index.end(),
1605                            selected_index->begin())) {
1606               auto source = source_literal.Get<ReturnT>(source_index);
1607               auto scattered = result->Get<ReturnT>(operand_index);
1608               const auto source_literal = Literal::CreateR0<ReturnT>(source);
1609               const auto scattered_literal =
1610                   Literal::CreateR0<ReturnT>(scattered);
1611 
1612               const std::vector<const Literal*> args = {
1613                   source_literal.get(), scattered_literal.get()};
1614               std::unique_ptr<Literal> computed_result =
1615                   embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
1616                       .ConsumeValueOrDie();
1617               result->Set(operand_index, computed_result->Get<ReturnT>({}));
1618               // Clear visit states so that the we can use the evaluator again
1619               // on the same computation.
1620               embedded_evaluator.ResetVisitStates();
1621             }
1622           });
1623     } while (IndexUtil::BumpIndices(source->shape(), &source_index));
1624 
1625     parent_->evaluated_[select_and_scatter] = std::move(result);
1626     return Status::OK();
1627   }
1628 
HandleReduceWindow(HloInstruction * reduce_window)1629   Status HandleReduceWindow(HloInstruction* reduce_window) override {
1630     auto operand = reduce_window->operand(0);
1631     const Window& window = reduce_window->window();
1632     HloComputation* function = reduce_window->to_apply();
1633     TF_ASSIGN_OR_RETURN(
1634         auto inferred_return_shape,
1635         ShapeInference::InferReduceWindowShape(
1636             /*operand_shape=*/reduce_window->operand(0)->shape(),
1637             /*init_value=*/reduce_window->operand(1)->shape(), window,
1638             /*to_apply_shape=*/function->ComputeProgramShape()));
1639     TF_RET_CHECK(
1640         ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
1641         << "return shape is set to: "
1642         << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
1643         << "but is inferred to be: "
1644         << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
1645 
1646     const Literal& operand_literal =
1647         parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
1648     VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
1649     const Literal& init_literal =
1650         parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
1651     VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
1652     TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
1653     auto init_scalar = init_literal.Get<ReturnT>({});
1654 
1655     auto result = Literal::CreateFromShape(reduce_window->shape());
1656 
1657     // Creates a Shape object from window, for iteration below.
1658     std::vector<int64> window_dimension_sizes;
1659     for (const auto& window_dimension : window.dimensions()) {
1660       window_dimension_sizes.push_back(window_dimension.size());
1661     }
1662     const Shape window_shape = ShapeUtil::MakeShape(
1663         operand->shape().element_type(), window_dimension_sizes);
1664 
1665     DimensionVector window_index(window.dimensions_size());
1666     DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
1667 
1668     HloEvaluator embedded_evaluator;
1669     // For each resulting dimension, calculate and assign computed value.
1670     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1671         [&](tensorflow::gtl::ArraySlice<int64> output_index) {
1672           ReturnT result_val = init_scalar;
1673 
1674           std::fill(window_index.begin(), window_index.end(), 0);
1675           std::fill(operand_index.begin(), operand_index.end(), 0);
1676 
1677           IterateThroughWindow(
1678               window_shape, window, operand_literal.shape(), output_index,
1679               [&](const std::vector<int64>& operand_index) {
1680                 auto curr_val = operand_literal.Get<ReturnT>(operand_index);
1681 
1682                 // Evaluate computation with specified literal operands.
1683                 const auto curr_val_literal =
1684                     Literal::CreateR0<ReturnT>(curr_val);
1685                 const auto result_val_literal =
1686                     Literal::CreateR0<ReturnT>(result_val);
1687                 const std::vector<const Literal*> args = {
1688                     curr_val_literal.get(), result_val_literal.get()};
1689                 std::unique_ptr<Literal> computed_result =
1690                     embedded_evaluator.Evaluate<const Literal*>(*function, args)
1691                         .ConsumeValueOrDie();
1692 
1693                 // Clear visit states so that the we can use the evaluate again
1694                 // on the same computation.
1695                 embedded_evaluator.ResetVisitStates();
1696 
1697                 result_val = computed_result->Get<ReturnT>({});
1698               });
1699 
1700           return result_val;
1701         }));
1702 
1703     parent_->evaluated_[reduce_window] = std::move(result);
1704     return Status::OK();
1705   }
1706 
HandleSlice(HloInstruction * slice)1707   Status HandleSlice(HloInstruction* slice) override {
1708     auto operand = slice->operand(0);
1709     const Shape& shape = slice->shape();
1710     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
1711                         ShapeInference::InferSliceShape(
1712                             operand->shape(), slice->slice_starts(),
1713                             slice->slice_limits(), slice->slice_strides()));
1714     TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape))
1715         << "return shape set to: " << ShapeUtil::HumanString(shape)
1716         << " but is inferred to be: "
1717         << ShapeUtil::HumanString(inferred_return_shape);
1718 
1719     const int64 rank = ShapeUtil::Rank(operand->shape());
1720     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
1721     auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
1722       DimensionVector operand_index(rank);
1723       for (int64 i = 0; i < rank; ++i) {
1724         operand_index[i] =
1725             slice->slice_starts(i) + out_index[i] * slice->slice_strides(i);
1726       }
1727       return operand_literal.Get<ReturnT>(operand_index);
1728     };
1729 
1730     auto result = Literal::CreateFromDimensions(
1731         shape.element_type(), AsInt64Slice(shape.dimensions()));
1732     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
1733     parent_->evaluated_[slice] = std::move(result);
1734     return Status::OK();
1735   }
1736 
1737   template <typename NativeT, typename std::enable_if<std::is_floating_point<
1738                                   NativeT>::value>::type* = nullptr>
HandleSin(HloInstruction * sin)1739   Status HandleSin(HloInstruction* sin) {
1740     TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin],
1741                         ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) {
1742                           return std::sin(elem_operand);
1743                         }));
1744     return Status::OK();
1745   }
1746 
1747   template <
1748       typename NativeT,
1749       typename std::enable_if<std::is_integral<NativeT>::value ||
1750                               is_complex_t<NativeT>::value>::type* = nullptr>
HandleSin(HloInstruction * sin)1751   Status HandleSin(HloInstruction* sin) {
1752     return InvalidArgument("Unsupported type for Sin");
1753   }
1754 
HandleSin(HloInstruction * sin)1755   Status HandleSin(HloInstruction* sin) override {
1756     return HandleSin<ElementwiseT>(sin);
1757   }
1758 
1759   template <typename NativeT, typename std::enable_if<std::is_floating_point<
1760                                   NativeT>::value>::type* = nullptr>
HandleCos(HloInstruction * cos)1761   Status HandleCos(HloInstruction* cos) {
1762     TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos],
1763                         ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) {
1764                           return std::cos(elem_operand);
1765                         }));
1766     return Status::OK();
1767   }
1768 
1769   template <
1770       typename NativeT,
1771       typename std::enable_if<std::is_integral<NativeT>::value ||
1772                               is_complex_t<NativeT>::value>::type* = nullptr>
HandleCos(HloInstruction * cos)1773   Status HandleCos(HloInstruction* cos) {
1774     return InvalidArgument("Unsupported type for Cos");
1775   }
1776 
HandleCos(HloInstruction * cos)1777   Status HandleCos(HloInstruction* cos) override {
1778     return HandleCos<ElementwiseT>(cos);
1779   }
1780 
  // Evaluates kReducePrecision for f32 elements (enabled only when
  // NativeT == float).
  //
  // Each element is rounded, round-to-nearest with ties to even, to the
  // instruction's `mantissa_bits`. Its exponent is then clamped to the range
  // representable with `exponent_bits`:
  //   - overflow becomes a signed infinity;
  //   - underflow, including all denormals, becomes a signed zero.
  // NaN inputs bypass the bit manipulation. They are returned unchanged when
  // mantissa_bits > 0, and as +infinity when mantissa_bits == 0.
  template <typename NativeT, typename std::enable_if<std::is_same<
                                  float, NativeT>::value>::type* = nullptr>
  Status HandleReducePrecision(HloInstruction* reduce_precision) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[reduce_precision],
        ElementWiseUnaryOp(reduce_precision, [reduce_precision](
                                                 ElementwiseT elem) {
          uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem);
          const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
          const uint32_t exponent_bits = reduce_precision->exponent_bits();

          // Code is based on the CPU/GPU implementation in LLVM-emitting code.
          //
          // Bits in float type:
          //   mantissa : bits [0:22]
          //   exponent : bits [23:30]
          //   sign     : bits [31]
          if (mantissa_bits < 23) {
            // Mask selecting the lowest mantissa bit that survives truncation.
            const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);

            // Compute rounding bias for round-to-nearest with ties to even.
            // This is equal to a base value of 0111... plus one bit if the last
            // remaining mantissa bit is 1.
            const uint32_t base_rounding_bias =
                (last_mantissa_bit_mask >> 1) - 1;
            const uint32_t x_last_mantissa_bit =
                (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
            const uint32_t x_rounding_bias =
                x_last_mantissa_bit + base_rounding_bias;

            // Add rounding bias, and mask out truncated bits.  Note that the
            // case where adding the rounding bias overflows into the exponent
            // bits is correct; the non-masked mantissa bits will all be zero,
            // and the exponent will be incremented by one.
            const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
            value_as_int = value_as_int + x_rounding_bias;
            value_as_int = value_as_int & truncation_mask;
          }
          if (exponent_bits < 8) {
            // Masks for f32 values.
            const uint32_t f32_sign_bit_mask = 1u << 31;
            const uint32_t f32_exp_bits_mask = 0xffu << 23;

            // An exponent of 2^(n-1)-1 (that is, 0111... with the zero in the
            // most-significant bit) is equal to 1.0f for all exponent sizes.
            // Adding 2^(n-1)-1 to this gives us the highest non-infinite
            // exponent for a bit size of n, and subtracting 2^(n-1)-1 from
            // this gives us the lowest exponent (corresponding to 0.0f).
            //
            // Thus, the f32 exponent corresponding to the highest non-infinite
            // exponent for a bit size of n is (2^7-1) + (2^(n-1)-1), and the
            // f32 exponent corresponding to the lowest exponent for a bit size
            // of n is (2^7-1) - (2^(n-1)-1).
            //
            // NOTE(review): this assumes exponent_bits >= 1, presumably
            // validated before evaluation; it is not checked here. With
            // exponent_bits == 0, the shift amount (exponent_bits - 1) below
            // would wrap around.
            const uint32_t f32_exponent_bias = (1 << 7) - 1;
            const uint32_t reduced_exponent_bias =
                (1 << (exponent_bits - 1)) - 1;
            const uint32_t reduced_max_exponent =
                f32_exponent_bias + reduced_exponent_bias;
            const uint32_t reduced_min_exponent =
                f32_exponent_bias - reduced_exponent_bias;

            // Do we overflow or underflow?
            const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
            const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
            const bool x_underflows =
                x_exponent <= (reduced_min_exponent << 23);

            // Compute appropriately-signed values of zero and infinity.
            const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
            const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;

            // Force to zero or infinity if overflow or underflow.  (Note that
            // this truncates all denormal values to zero, rather than rounding
            // them.)
            value_as_int = x_overflows ? x_signed_inf : value_as_int;
            value_as_int = x_underflows ? x_signed_zero : value_as_int;
          }

          float reduced_result = tensorflow::bit_cast<float>(value_as_int);
          // NaN inputs ignore the bit manipulation above: the NaN is kept if
          // any mantissa bits remain, otherwise +infinity is returned.
          if (std::isnan(elem)) {
            reduced_result = mantissa_bits > 0
                                 ? elem
                                 : std::numeric_limits<float>::infinity();
          }
          return reduced_result;
        }));
    return Status::OK();
  }
1871 
1872   template <typename NativeT, typename std::enable_if<std::is_same<
1873                                   double, NativeT>::value>::type* = nullptr>
HandleReducePrecision(HloInstruction * reduce_precision)1874   Status HandleReducePrecision(HloInstruction* reduce_precision) {
1875     return InvalidArgument("Double not supported for reduce precision");
1876   }
1877 
1878   template <
1879       typename NativeT,
1880       typename std::enable_if<std::is_integral<NativeT>::value ||
1881                               is_complex_t<NativeT>::value>::type* = nullptr>
HandleReducePrecision(HloInstruction * reduce_precision)1882   Status HandleReducePrecision(HloInstruction* reduce_precision) {
1883     return InvalidArgument("Unsupported type for reduce precision");
1884   }
1885 
HandleReducePrecision(HloInstruction * reduce_precision)1886   Status HandleReducePrecision(HloInstruction* reduce_precision) override {
1887     return HandleReducePrecision<ElementwiseT>(reduce_precision);
1888   }
1889 
1890  private:
1891   template <typename IndexT>
DynamicSlice(const Literal & operand_literal,const Literal & start_indices_literal,const Shape & result_shape)1892   StatusOr<std::unique_ptr<Literal>> DynamicSlice(
1893       const Literal& operand_literal, const Literal& start_indices_literal,
1894       const Shape& result_shape) {
1895     auto start_indices_typed = start_indices_literal.data<IndexT>();
1896     std::vector<int64> start(start_indices_typed.begin(),
1897                              start_indices_typed.end());
1898 
1899     std::vector<int64> operand_indices(start.size());
1900 
1901     auto result = Literal::CreateFromShape(result_shape);
1902     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1903         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
1904           for (int64 i = 0; i < operand_indices.size(); ++i) {
1905             CHECK_GE(multi_index[i] + start[i], 0);
1906             // Mod is only used here to be consistent with the existing
1907             // backends' behavior.
1908             operand_indices[i] = (multi_index[i] + start[i]) %
1909                                  operand_literal.shape().dimensions(i);
1910           }
1911 
1912           auto result = operand_literal.Get<ReturnT>(operand_indices);
1913           return result;
1914         }));
1915 
1916     return std::move(result);
1917   }
1918 
1919   template <typename IndexT>
DynamicUpdateSlice(const Literal & operand_literal,const Literal & update_literal,const Literal & start_indices_literal)1920   StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
1921       const Literal& operand_literal, const Literal& update_literal,
1922       const Literal& start_indices_literal) {
1923     auto start_indices_typed = start_indices_literal.data<IndexT>();
1924     const std::vector<int64> start(start_indices_typed.begin(),
1925                                    start_indices_typed.end());
1926 
1927     auto result = operand_literal.CloneToUnique();
1928     std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
1929 
1930     auto func = [&](const std::vector<int64>& update_index) {
1931       std::transform(update_index.begin(), update_index.end(), start.begin(),
1932                      result_index.begin(), std::plus<int64>());
1933 
1934       result->Set<ReturnT>(result_index,
1935                            update_literal.Get<ReturnT>(update_index));
1936       return true;
1937     };
1938 
1939     std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
1940     std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
1941     ShapeUtil::ForEachIndex(update_literal.shape(), base,
1942                             AsInt64Slice(update_literal.shape().dimensions()),
1943                             step, func);
1944 
1945     return std::move(result);
1946   }
1947 
ElementWiseUnaryOp(HloInstruction * instruction,const std::function<ElementwiseT (ElementwiseT)> & unary_op)1948   StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
1949       HloInstruction* instruction,
1950       const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
1951     const Literal& operand_literal =
1952         parent_->GetEvaluatedLiteralFor(instruction->operand(0));
1953     TF_ASSIGN_OR_RETURN(
1954         auto result_literal,
1955         (ElementWiseUnaryOpImpl<ReturnT, ReturnT>(
1956             instruction, ConvertUnaryFunction(unary_op), operand_literal)));
1957 
1958     return std::move(result_literal);
1959   }
1960 
ElementWiseBinaryOp(HloInstruction * instruction,const std::function<ElementwiseT (ElementwiseT,ElementwiseT)> & binary_op)1961   StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
1962       HloInstruction* instruction,
1963       const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
1964           binary_op) {
1965     const auto shape = instruction->shape();
1966     const auto* lhs = instruction->operand(0);
1967     const auto* rhs = instruction->operand(1);
1968 
1969     // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast
1970     // is removed.
1971     if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) &&
1972           ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
1973       return Unimplemented(
1974           "Implicit broadcasting is currently unsupported in HLO evaluator "
1975           "Shape Mismatch: %s vs %s vs %s: ",
1976           ShapeUtil::HumanString(shape).c_str(),
1977           ShapeUtil::HumanString(lhs->shape()).c_str(),
1978           ShapeUtil::HumanString(rhs->shape()).c_str());
1979     }
1980 
1981     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
1982     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
1983 
1984     auto result = Literal::CreateFromShape(shape);
1985 
1986     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1987         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
1988           return ConvertBinaryFunction(binary_op)(
1989               lhs_literal.Get<ReturnT>(multi_index),
1990               rhs_literal.Get<ReturnT>(multi_index));
1991         }));
1992     return std::move(result);
1993   }
1994 
1995   template <typename LhsType, typename RhsType, typename EhsType>
ElementwiseTernaryOp(HloInstruction * instruction,const std::function<ReturnT (LhsType,RhsType,EhsType)> & ternary_op)1996   StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
1997       HloInstruction* instruction,
1998       const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
1999     const auto shape = instruction->shape();
2000     const auto* lhs = instruction->operand(0);
2001     const auto* rhs = instruction->operand(1);
2002     const auto* ehs = instruction->operand(2);
2003 
2004     // TODO(b/35950897, b/27796129): add DCHECK back once implicit
2005     // broadcast is removed.
2006     if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
2007           ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
2008           ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
2009       return Unimplemented(
2010           "Implicit broadcasting is currently unsupported in HLO evaluator "
2011           "Shape Mismatch: %s vs %s vs %s vs %s: ",
2012           ShapeUtil::HumanString(shape).c_str(),
2013           ShapeUtil::HumanString(lhs->shape()).c_str(),
2014           ShapeUtil::HumanString(rhs->shape()).c_str(),
2015           ShapeUtil::HumanString(ehs->shape()).c_str());
2016     }
2017 
2018     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
2019     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
2020     const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
2021 
2022     auto result = Literal::CreateFromShape(shape);
2023 
2024     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
2025         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
2026           return ternary_op(lhs_literal.Get<LhsType>(multi_index),
2027                             rhs_literal.Get<RhsType>(multi_index),
2028                             ehs_literal.Get<EhsType>(multi_index));
2029         }));
2030 
2031     return std::move(result);
2032   }
2033 
2034   HloEvaluator* parent_;
2035 };  // class HloEvaluator::TypedVisitor
2036 
HloEvaluator()2037 HloEvaluator::HloEvaluator() {
2038   typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
2039   typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
2040   typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2041     return Unimplemented("HloEvaluator: unhandled primitive type: U16.");
2042   });
2043   typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
2044   typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
2045   typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
2046   typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2047     return Unimplemented("HloEvaluator: unhandled primitive type: S16.");
2048   });
2049   typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
2050   typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
2051   typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this);
2052   typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
2053   typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
2054   typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
2055 
2056   // Most of the evaluator computations we use don't support BF16 (e.g.,
2057   // std::ceil, std::tanh). To make evaluator work with BF16, we set all
2058   // elementwise computations to be done in F32 and do BF16<->F32 conversion
2059   // around the input and the output of the computations.
2060   typed_visitors_[BF16] = MakeUnique<TypedVisitor<bfloat16, float>>(this);
2061   typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2062     return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
2063   });
2064   typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2065     return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE.");
2066   });
2067 }
2068 
2069 template <typename LiteralPtr>
Evaluate(const HloModule & module,tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals)2070 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2071     const HloModule& module,
2072     tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) {
2073   XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
2074 
2075   evaluated_.clear();
2076   arg_literals_.clear();
2077   for (const auto& literal_ptr : arg_literals) {
2078     arg_literals_.push_back(&*literal_ptr);
2079   }
2080 
2081   TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
2082 
2083   return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
2084       .CloneToUnique();
2085 }
2086 
2087 template <typename LiteralPtr>
Evaluate(const HloComputation & computation,tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals)2088 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2089     const HloComputation& computation,
2090     tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) {
2091   XLA_VLOG_LINES(
2092       2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
2093 
2094   evaluated_.clear();
2095   arg_literals_.clear();
2096   for (const auto& literal_ptr : arg_literals) {
2097     arg_literals_.push_back(&*literal_ptr);
2098   }
2099 
2100   TF_RETURN_IF_ERROR(computation.Accept(this));
2101   return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
2102 }
2103 
2104 template <typename LiteralPtr>
Evaluate(HloInstruction * instruction,tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals)2105 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2106     HloInstruction* instruction,
2107     tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) {
2108   TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
2109   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
2110 
2111   evaluated_.clear();
2112   arg_literals_.clear();
2113   for (const auto& literal_ptr : arg_literals) {
2114     arg_literals_.push_back(&*literal_ptr);
2115   }
2116 
2117   // Evaluate operands of Parameter type against the input literals which
2118   // caches the evaluated literal results.
2119   for (const auto operand : instruction->operands()) {
2120     if (operand->opcode() == HloOpcode::kParameter) {
2121       const Literal* input_literal = arg_literals_[operand->parameter_number()];
2122       VLOG(2) << "Parameter operand evaluated to: "
2123               << input_literal->ToString();
2124       TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
2125 
2126       evaluated_[operand] = input_literal->CloneToUnique();
2127     }
2128   }
2129 
2130   TF_RETURN_IF_ERROR(Preprocess(instruction));
2131   TF_RETURN_IF_ERROR(instruction->Visit(this));
2132   TF_RETURN_IF_ERROR(Postprocess(instruction));
2133   return GetEvaluatedLiteralFor(instruction).CloneToUnique();
2134 }
2135 
Evaluate(HloInstruction * instruction)2136 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2137     HloInstruction* instruction) {
2138   if (instruction->opcode() == HloOpcode::kParameter) {
2139     return tensorflow::errors::FailedPrecondition(
2140         "Cannot evaluate a parameter.");
2141   }
2142   if (!hlo_query::AllOperandsAreConstants(*instruction)) {
2143     return tensorflow::errors::FailedPrecondition(
2144         "Not all operands are constants.");
2145   }
2146   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
2147 
2148   arg_literals_.clear();
2149   evaluated_.clear();
2150 
2151   TF_RETURN_IF_ERROR(Preprocess(instruction));
2152   TF_RETURN_IF_ERROR(instruction->Visit(this));
2153   TF_RETURN_IF_ERROR(Postprocess(instruction));
2154   return GetEvaluatedLiteralFor(instruction).CloneToUnique();
2155 }
2156 
TryEvaluate(HloInstruction * instruction)2157 std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
2158     HloInstruction* instruction) {
2159   auto result_or = Evaluate(instruction);
2160   if (!result_or.ok()) {
2161     VLOG(1) << "TryEvaluate failed:" << result_or.status();
2162     return nullptr;
2163   }
2164 
2165   return result_or.ConsumeValueOrDie();
2166 }
2167 
EvaluateWithSubstitutions(const HloInstruction * instruction,const std::unordered_map<const HloInstruction *,const Literal * > & substitutions)2168 StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
2169     const HloInstruction* instruction,
2170     const std::unordered_map<const HloInstruction*, const Literal*>&
2171         substitutions) {
2172   std::vector<std::unique_ptr<HloInstruction>> owned_operands;
2173   for (const HloInstruction* operand : instruction->operands()) {
2174     auto it = substitutions.find(operand);
2175     if (it == substitutions.end()) {
2176       owned_operands.push_back(operand->Clone());
2177     } else {
2178       owned_operands.push_back(
2179           HloInstruction::CreateConstant(it->second->CloneToUnique()));
2180     }
2181   }
2182 
2183   std::vector<HloInstruction*> operands;
2184   operands.reserve(owned_operands.size());
2185   for (auto& operand : owned_operands) {
2186     operands.push_back(operand.get());
2187   }
2188 
2189   std::unique_ptr<HloInstruction> cloned_instruction =
2190       instruction->CloneWithNewOperands(instruction->shape(), operands);
2191   auto result = Evaluate(cloned_instruction.get());
2192 
2193   // Clean up our cloned instructions before returning.
2194   cloned_instruction->DetachFromOperands();
2195   for (auto& operand : owned_operands) {
2196     operand->DetachFromOperands();
2197   }
2198 
2199   return result;
2200 }
2201 
HandleParameter(HloInstruction * parameter)2202 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
2203   CHECK_LT(parameter->parameter_number(), arg_literals_.size());
2204   const Literal* input_literal = arg_literals_[parameter->parameter_number()];
2205   VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
2206   DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
2207       << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
2208       << ", but input literal shape is: "
2209       << ShapeUtil::HumanString(input_literal->shape());
2210 
2211   evaluated_[parameter] = input_literal->CloneToUnique();
2212   return Status::OK();
2213 }
2214 
HandleConstant(HloInstruction *)2215 Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
2216 
HandleReshape(HloInstruction * reshape)2217 Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
2218   TF_ASSIGN_OR_RETURN(
2219       evaluated_[reshape],
2220       GetEvaluatedLiteralFor(reshape->operand(0))
2221           .Reshape(AsInt64Slice(reshape->shape().dimensions())));
2222   return Status::OK();
2223 }
2224 
HandleTranspose(HloInstruction * transpose)2225 Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
2226   evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
2227                               .Transpose(transpose->dimensions());
2228   return Status::OK();
2229 }
2230 
HandleConcatenate(HloInstruction * concatenate)2231 Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
2232   tensorflow::gtl::ArraySlice<HloInstruction*> operands(
2233       concatenate->operands());
2234   // The result concatenate dimension is going to be the sum of all
2235   // concatenate dimensions of the operands taking part of the operation.
2236   const Shape& reference_shape = operands[0]->shape();
2237   CHECK(!ShapeUtil::IsTuple(reference_shape));
2238   const int64 rank = ShapeUtil::Rank(reference_shape);
2239   const int64 concat_dim = concatenate->dimensions()[0];
2240   CHECK_GE(concat_dim, 0);
2241   CHECK_LT(concat_dim, rank);
2242 
2243   DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
2244                                     reference_shape.dimensions().end());
2245 
2246   for (int64 i = 1; i < operands.size(); ++i) {
2247     const Shape& operand_shape = operands[i]->shape();
2248     CHECK(!ShapeUtil::IsTuple(operand_shape));
2249     // Accumulate the concat dimension from all tensors taking part to the
2250     // operation.
2251     concat_dimensions[concat_dim] +=
2252         ShapeUtil::GetDimension(operand_shape, concat_dim);
2253   }
2254 
2255   auto result_literal = Literal::CreateFromDimensions(
2256       reference_shape.element_type(), concat_dimensions);
2257   DimensionVector source_indices(rank, 0);
2258   DimensionVector dest_indices(concat_dimensions.size(), 0);
2259 
2260   for (auto operand : operands) {
2261     const Shape& operand_shape = operand->shape();
2262     TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
2263         GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
2264         AsInt64Slice(operand_shape.dimensions())));
2265     dest_indices[concat_dim] +=
2266         ShapeUtil::GetDimension(operand_shape, concat_dim);
2267   }
2268 
2269   evaluated_[concatenate] = std::move(result_literal);
2270   return Status::OK();
2271 }
2272 
HandleIsFinite(HloInstruction * is_finite)2273 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
2274   auto operand = is_finite->operand(0);
2275   if (!ShapeUtil::ElementIsFloating(operand->shape())) {
2276     return InvalidArgument(
2277         "expected element type in shape to be float for IsFinite op, got: %s",
2278         PrimitiveType_Name(operand->shape().element_type()).c_str());
2279   }
2280 
2281   switch (operand->shape().element_type()) {
2282     case F16:
2283       return Unimplemented("unhandled primitive type: F16.");
2284     case F32: {
2285       auto result_or = ElementWiseUnaryOpImpl<bool, float>(
2286           is_finite,
2287           [](float elem_operand) { return std::isfinite(elem_operand); },
2288           GetEvaluatedLiteralFor(operand));
2289       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
2290       break;
2291     }
2292     case F64: {
2293       auto result_or = ElementWiseUnaryOpImpl<bool, double>(
2294           is_finite,
2295           [](double elem_operand) { return std::isfinite(elem_operand); },
2296           GetEvaluatedLiteralFor(operand));
2297       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
2298       break;
2299     }
2300     default:
2301       LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: "
2302                  << PrimitiveType_Name(operand->shape().element_type());
2303   }
2304 
2305   return Status::OK();
2306 }
2307 
HandleCompare(HloInstruction * compare)2308 Status HloEvaluator::HandleCompare(HloInstruction* compare) {
2309   HloOpcode opcode = compare->opcode();
2310   auto lhs = compare->operand(0);
2311   auto rhs = compare->operand(1);
2312   // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
2313   // removed.
2314   if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
2315         ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
2316     return Unimplemented(
2317         "Implicit broadcasting is currently unsupported in HLO evaluator "
2318         "Shape Mismatch: %s vs %s vs %s",
2319         ShapeUtil::HumanString(compare->shape()).c_str(),
2320         ShapeUtil::HumanString(lhs->shape()).c_str(),
2321         ShapeUtil::HumanString(rhs->shape()).c_str());
2322   }
2323 
2324   TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
2325 
2326   const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
2327   const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
2328 
2329   // Note here we switch on the operand's type.
2330   switch (lhs->shape().element_type()) {
2331     case PRED: {
2332       TF_ASSIGN_OR_RETURN(
2333           evaluated_[compare],
2334           Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
2335     } break;
2336     case U8: {
2337       TF_ASSIGN_OR_RETURN(
2338           evaluated_[compare],
2339           Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
2340     } break;
2341     case U16:
2342       return Unimplemented("unhandled primitive type: U16.");
2343     case U32: {
2344       TF_ASSIGN_OR_RETURN(
2345           evaluated_[compare],
2346           Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
2347     } break;
2348     case U64: {
2349       TF_ASSIGN_OR_RETURN(
2350           evaluated_[compare],
2351           Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
2352     } break;
2353     case S8: {
2354       TF_ASSIGN_OR_RETURN(
2355           evaluated_[compare],
2356           Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
2357     } break;
2358     case S16:
2359       return Unimplemented("unhandled primitive type: S16.");
2360     case S32: {
2361       TF_ASSIGN_OR_RETURN(
2362           evaluated_[compare],
2363           Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
2364     } break;
2365     case S64: {
2366       TF_ASSIGN_OR_RETURN(
2367           evaluated_[compare],
2368           Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
2369     } break;
2370     case F16:
2371       return Unimplemented("unhandled primitive type: F16.");
2372     case F32: {
2373       TF_ASSIGN_OR_RETURN(
2374           evaluated_[compare],
2375           Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
2376     } break;
2377     case F64: {
2378       TF_ASSIGN_OR_RETURN(
2379           evaluated_[compare],
2380           Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
2381     } break;
2382     case C64: {
2383       TF_ASSIGN_OR_RETURN(evaluated_[compare],
2384                           Compare<complex64>(compare->shape(), opcode,
2385                                              lhs_literal, rhs_literal));
2386     } break;
2387     default:
2388       LOG(FATAL) << "HandleCompare: unknown primitive type: "
2389                  << PrimitiveType_Name(lhs->shape().element_type());
2390   }
2391 
2392   return Status::OK();
2393 }
2394 
HandleTuple(HloInstruction * tuple)2395 Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
2396   std::vector<const Literal*> operand_literals;
2397   for (auto operand : tuple->operands()) {
2398     operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
2399   }
2400 
2401   evaluated_[tuple] = Literal::MakeTuple(operand_literals);
2402   return Status::OK();
2403 }
2404 
HandleGetTupleElement(HloInstruction * get_tuple_element)2405 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
2406   const auto result_shape = get_tuple_element->shape();
2407   const int64 index = get_tuple_element->tuple_index();
2408 
2409   auto operand = get_tuple_element->operand(0);
2410   TF_ASSIGN_OR_RETURN(
2411       auto inferred_return_shape,
2412       ShapeInference::InferGetTupleElementShape(operand->shape(), index));
2413   TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
2414       << "return shape set to: " << ShapeUtil::HumanString(result_shape)
2415       << " but is inferred to be: "
2416       << ShapeUtil::HumanString(inferred_return_shape);
2417 
2418   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
2419 
2420   evaluated_[get_tuple_element] = MakeUnique<Literal>(
2421       ShapeUtil::GetTupleElementShape(operand->shape(), index));
2422   return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
2423                                                  /*dest_shape_index=*/{},
2424                                                  /*src_shape_index=*/{index});
2425 }
2426 
HandleCopy(HloInstruction * copy)2427 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
2428   TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
2429 
2430   auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
2431   evaluated_[copy] = std::move(result);
2432   return Status::OK();
2433 }
2434 
Preprocess(HloInstruction * hlo)2435 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
2436   VLOG(2) << "About to visit HLO: " << hlo->ToString();
2437   return Status::OK();
2438 }
2439 
Postprocess(HloInstruction * hlo)2440 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
2441   VLOG(2) << "Finished visiting " << hlo->ToString()
2442           << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
2443   return Status::OK();
2444 }
2445 
2446 // Explicit instantiation of templatized Evaluate* methods.
2447 //
2448 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
2449     const Literal*>(const HloModule& module,
2450                     tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
2451 template StatusOr<std::unique_ptr<Literal>>
2452 HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
2453     const HloModule& module,
2454     tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals);
2455 
2456 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
2457     const Literal*>(const HloComputation& computation,
2458                     tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
2459 template StatusOr<std::unique_ptr<Literal>>
2460 HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
2461     const HloComputation& computation,
2462     tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals);
2463 
2464 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
2465     const Literal*>(HloInstruction* instruction,
2466                     tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
2467 template StatusOr<std::unique_ptr<Literal>>
2468 HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
2469     HloInstruction* instruction,
2470     tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals);
2471 
2472 }  // namespace xla
2473