• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstdlib>
20 #include <functional>
21 #include <iterator>
22 #include <string>
23 #include <type_traits>
24 #include <vector>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/index_util.h"
31 #include "tensorflow/compiler/xla/layout_util.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/map_util.h"
34 #include "tensorflow/compiler/xla/primitive_util.h"
35 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
36 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
37 #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
39 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
40 #include "tensorflow/compiler/xla/service/hlo_query.h"
41 #include "tensorflow/compiler/xla/service/shape_inference.h"
42 #include "tensorflow/compiler/xla/shape_util.h"
43 #include "tensorflow/compiler/xla/statusor.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/compiler/xla/window_util.h"
47 #include "tensorflow/core/lib/core/bitmap.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/protobuf.h"
52 #include "tensorflow/core/platform/types.h"
53 
54 namespace xla {
55 
56 namespace {
57 
58 template <typename OperandT>
Compare(const Shape & shape,ComparisonDirection direction,LiteralSlice lhs_literal,LiteralSlice rhs_literal)59 StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
60                           LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
61   std::function<bool(OperandT, OperandT)> compare_op;
62   switch (direction) {
63     case ComparisonDirection::kEq:
64       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
65         return lhs_el == rhs_el;
66       };
67       break;
68     case ComparisonDirection::kNe:
69       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
70         return lhs_el != rhs_el;
71       };
72       break;
73     case ComparisonDirection::kGe:
74       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
75         return lhs_el >= rhs_el;
76       };
77       break;
78     case ComparisonDirection::kGt:
79       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
80         return lhs_el > rhs_el;
81       };
82       break;
83     case ComparisonDirection::kLe:
84       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
85         return lhs_el <= rhs_el;
86       };
87       break;
88     case ComparisonDirection::kLt:
89       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
90         return lhs_el < rhs_el;
91       };
92       break;
93   }
94 
95   Literal result(shape);
96   TF_RETURN_IF_ERROR(
97       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
98         return compare_op(lhs_literal.Get<OperandT>(multi_index),
99                           rhs_literal.Get<OperandT>(multi_index));
100       }));
101 
102   return std::move(result);
103 }
104 
105 template <>
Compare(const Shape & shape,ComparisonDirection direction,LiteralSlice lhs_literal,LiteralSlice rhs_literal)106 StatusOr<Literal> Compare<complex64>(const Shape& shape,
107                                      ComparisonDirection direction,
108                                      LiteralSlice lhs_literal,
109                                      LiteralSlice rhs_literal) {
110   std::function<bool(complex64, complex64)> compare_op;
111   switch (direction) {
112     case ComparisonDirection::kEq:
113       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
114         return lhs_el == rhs_el;
115       };
116       break;
117     case ComparisonDirection::kNe:
118       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
119         return lhs_el != rhs_el;
120       };
121       break;
122     default:
123       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
124                  << ComparisonDirectionToString(direction);
125   }
126 
127   Literal result(shape);
128   TF_RETURN_IF_ERROR(
129       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
130         return compare_op(lhs_literal.Get<complex64>(multi_index),
131                           rhs_literal.Get<complex64>(multi_index));
132       }));
133 
134   return std::move(result);
135 }
136 
137 template <>
Compare(const Shape & shape,ComparisonDirection direction,LiteralSlice lhs_literal,LiteralSlice rhs_literal)138 StatusOr<Literal> Compare<complex128>(const Shape& shape,
139                                       ComparisonDirection direction,
140                                       LiteralSlice lhs_literal,
141                                       LiteralSlice rhs_literal) {
142   std::function<bool(complex128, complex128)> compare_op;
143   switch (direction) {
144     case ComparisonDirection::kEq:
145       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
146         return lhs_el == rhs_el;
147       };
148       break;
149     case ComparisonDirection::kNe:
150       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
151         return lhs_el != rhs_el;
152       };
153       break;
154     default:
155       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
156                  << ComparisonDirectionToString(direction);
157   }
158 
159   Literal result(shape);
160   TF_RETURN_IF_ERROR(
161       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
162         return compare_op(lhs_literal.Get<complex128>(multi_index),
163                           rhs_literal.Get<complex128>(multi_index));
164       }));
165 
166   return std::move(result);
167 }
168 
169 }  // namespace
170 
171 // Note that unsupported types by the typed visitor does not necessarily imply
172 // the non-typed HloEvaluator (parent evaluator) would not support them either
173 // in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent
174 // type-agnostic evaluator will be able to accept Tuple primitive type, whereas
175 // HloEvaluatorTypedVisitor cannot.
HloEvaluator(int64 max_loop_iterations)176 HloEvaluator::HloEvaluator(int64 max_loop_iterations)
177     : max_loop_iterations_(max_loop_iterations) {
178   typed_visitors_[PRED] =
179       absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
180   typed_visitors_[U8] =
181       absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
182   typed_visitors_[U16] =
183       absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
184   typed_visitors_[U32] =
185       absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
186   typed_visitors_[U64] =
187       absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
188   typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
189   typed_visitors_[S16] =
190       absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
191   typed_visitors_[S32] =
192       absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
193   typed_visitors_[S64] =
194       absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
195   typed_visitors_[F16] =
196       absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
197   typed_visitors_[F32] =
198       absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
199   typed_visitors_[F64] =
200       absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
201   typed_visitors_[C64] =
202       absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
203   typed_visitors_[C128] =
204       absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
205 
206   // Most of the evaluator computations we use don't support BF16 (e.g.,
207   // std::ceil, std::tanh). To make evaluator work with BF16, we set all
208   // elementwise computations to be done in F32 and do BF16<->F32 conversion
209   // around the input and the output of the computations.
210   typed_visitors_[BF16] =
211       absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
212 
213   typed_visitors_[TUPLE] =
214       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
215         return Unimplemented(
216             "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
217       });
218   typed_visitors_[OPAQUE] =
219       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
220         return Unimplemented(
221             "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
222       });
223   typed_visitors_[TOKEN] =
224       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
225         return Unimplemented(
226             "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
227       });
228 }
229 
Evaluate(const HloComputation & computation,absl::Span<const Literal * const> arg_literals)230 StatusOr<Literal> HloEvaluator::Evaluate(
231     const HloComputation& computation,
232     absl::Span<const Literal* const> arg_literals) {
233   CHECK(computation.parent() != nullptr);
234   XLA_VLOG_LINES(
235       2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
236 
237   if (arg_literals.size() != computation.num_parameters()) {
238     return InvalidArgument(
239         "Expected %d argument%s, but got %d.", computation.num_parameters(),
240         computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
241   }
242   for (int64 i = 0; i < arg_literals.size(); ++i) {
243     const auto& computation_shape =
244         computation.parameter_instruction(i)->shape();
245     const auto& arg_shape = arg_literals[i]->shape();
246     if (!ShapeUtil::Equal(computation_shape, arg_shape)) {
247       return InvalidArgument(
248           "Shape mismatch at parameter %d. Computation expected %s, but arg "
249           "was %s.",
250           i, ShapeUtil::HumanStringWithLayout(computation_shape),
251           ShapeUtil::HumanString(arg_shape));
252     }
253   }
254 
255   evaluated_.clear();
256   arg_literals_.clear();
257   for (const auto& literal_ptr : arg_literals) {
258     arg_literals_.push_back(&*literal_ptr);
259   }
260 
261   // Re-seed RNG, either from the configuration's seed or a monotonic
262   // per-evaluator seed (which prevents two evaluators from returning the same
263   // random sequence).
264   if (computation.parent()->config().seed()) {
265     seed_ = computation.parent()->config().seed();
266   } else {
267     // Start global_seed at a (true) random value.
268     static std::atomic<uint64> global_seed{std::random_device()()};
269     seed_ = global_seed.fetch_add(1);
270   }
271   engine_.seed(seed_);
272 
273   TF_RETURN_IF_ERROR(computation.Accept(this));
274   return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
275 }
276 
Evaluate(HloInstruction * instruction)277 StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
278   if (instruction->opcode() == HloOpcode::kParameter) {
279     return tensorflow::errors::FailedPrecondition(
280         "Cannot evaluate a parameter.");
281   }
282   if (!hlo_query::AllOperandsAreConstants(*instruction)) {
283     return tensorflow::errors::FailedPrecondition(
284         "Not all operands are constants.");
285   }
286 
287   arg_literals_.clear();
288   evaluated_.clear();
289 
290   TF_RETURN_IF_ERROR(Preprocess(instruction));
291   TF_RETURN_IF_ERROR(instruction->Visit(this));
292   TF_RETURN_IF_ERROR(Postprocess(instruction));
293   return GetEvaluatedLiteralFor(instruction).Clone();
294 }
295 
TryEvaluate(HloInstruction * instruction,Literal * result)296 bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
297   CHECK(result != nullptr);
298   auto result_or = Evaluate(instruction);
299   if (!result_or.ok()) {
300     VLOG(1) << "TryEvaluate failed:" << result_or.status();
301     return false;
302   }
303 
304   *result = result_or.ConsumeValueOrDie();
305   return true;
306 }
307 
EvaluateWithSubstitutions(const HloInstruction * instruction,const std::unordered_map<const HloInstruction *,const Literal * > & substitutions)308 StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
309     const HloInstruction* instruction,
310     const std::unordered_map<const HloInstruction*, const Literal*>&
311         substitutions) {
312   std::vector<std::unique_ptr<HloInstruction>> owned_operands;
313   for (const HloInstruction* operand : instruction->operands()) {
314     auto it = substitutions.find(operand);
315     if (it == substitutions.end()) {
316       owned_operands.push_back(operand->Clone());
317     } else {
318       owned_operands.push_back(
319           HloInstruction::CreateConstant(it->second->Clone()));
320     }
321   }
322 
323   std::vector<HloInstruction*> operands;
324   operands.reserve(owned_operands.size());
325   for (auto& operand : owned_operands) {
326     operands.push_back(operand.get());
327   }
328 
329   std::unique_ptr<HloInstruction> cloned_instruction =
330       instruction->CloneWithNewOperands(instruction->shape(), operands);
331   auto result = Evaluate(cloned_instruction.get());
332 
333   return result;
334 }
335 
EvaluateElementwiseBinaryOp(HloOpcode opcode,const Literal & lhs,const Literal & rhs)336 StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
337     HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
338   std::unique_ptr<HloInstruction> lhs_instr =
339       HloInstruction::CreateConstant(lhs.Clone());
340   std::unique_ptr<HloInstruction> rhs_instr =
341       HloInstruction::CreateConstant(rhs.Clone());
342 
343   std::unique_ptr<HloInstruction> cloned_instruction =
344       HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
345                                    rhs_instr.get());
346   auto result = Evaluate(cloned_instruction.get());
347 
348   return result;
349 }
350 
EvaluateElementwiseUnaryOp(HloOpcode opcode,const Literal & operand)351 StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
352     HloOpcode opcode, const Literal& operand) {
353   std::unique_ptr<HloInstruction> operand_instr =
354       HloInstruction::CreateConstant(operand.Clone());
355 
356   std::unique_ptr<HloInstruction> cloned_instruction =
357       HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
358   auto result = Evaluate(cloned_instruction.get());
359 
360   return result;
361 }
362 
EvaluateDotOp(const DotDimensionNumbers & dim_numbers,const PrecisionConfig & precision_config,const Literal & lhs,const Literal & rhs)363 StatusOr<Literal> HloEvaluator::EvaluateDotOp(
364     const DotDimensionNumbers& dim_numbers,
365     const PrecisionConfig& precision_config, const Literal& lhs,
366     const Literal& rhs) {
367   std::unique_ptr<HloInstruction> lhs_instr =
368       HloInstruction::CreateConstant(lhs.Clone());
369   std::unique_ptr<HloInstruction> rhs_instr =
370       HloInstruction::CreateConstant(rhs.Clone());
371 
372   TF_ASSIGN_OR_RETURN(
373       Shape dot_shape,
374       ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
375 
376   std::unique_ptr<HloInstruction> cloned_instruction =
377       HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
378                                 dim_numbers, precision_config);
379   return Evaluate(cloned_instruction.get());
380 }
381 
HandleBitcast(HloInstruction * bitcast)382 Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
383   const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
384   Literal result(bitcast->shape());
385   TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
386   memcpy(result.untyped_data(), operand_literal.untyped_data(),
387          operand_literal.size_bytes());
388   evaluated_[bitcast] = std::move(result);
389   return Status::OK();
390 }
391 
HandleGetDimensionSize(HloInstruction * get_dimension_size)392 Status HloEvaluator::HandleGetDimensionSize(
393     HloInstruction* get_dimension_size) {
394   HloInstruction* operand = get_dimension_size->mutable_operand(0);
395   int64 dim = get_dimension_size->dimension();
396   if (dynamic_dimension_inference_ == nullptr) {
397     return InvalidArgument(
398         "Evaluator cannot evaluate get_dimension_size without "
399         "set_dynamic_dimension_inference.");
400   }
401   HloInstruction* dynamic_size =
402       dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
403   if (dynamic_size != nullptr) {
404     evaluated_[get_dimension_size] =
405         GetEvaluatedLiteralFor(dynamic_size).Clone();
406     return Status::OK();
407   }
408 
409   const Shape& shape = get_dimension_size->operand(0)->shape();
410   Literal output(ShapeUtil::MakeShape(U32, {}));
411   output.PopulateWithValue(
412       static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
413   evaluated_[get_dimension_size] = std::move(output);
414   return Status::OK();
415 }
416 
HandleParameter(HloInstruction * parameter)417 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
418   // Nothing to do other than sanity checks. Parameters' values are stored in
419   // arg_literals_.
420   CHECK_LT(parameter->parameter_number(), arg_literals_.size());
421 
422 #ifndef NDEBUG
423   const Literal* input_literal = arg_literals_[parameter->parameter_number()];
424   VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
425   DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
426       << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
427       << ", but input literal shape is: "
428       << ShapeUtil::HumanString(input_literal->shape());
429 #endif
430 
431   return Status::OK();
432 }
433 
HandleConstant(HloInstruction *)434 Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
435 
HandleReshape(HloInstruction * reshape)436 Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
437   TF_ASSIGN_OR_RETURN(
438       evaluated_[reshape],
439       GetEvaluatedLiteralFor(reshape->operand(0))
440           .Reshape(AsInt64Slice(reshape->shape().dimensions())));
441   return Status::OK();
442 }
443 
HandleTranspose(HloInstruction * transpose)444 Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
445   evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
446                               .Transpose(transpose->dimensions());
447   return Status::OK();
448 }
449 
HandleConcatenate(HloInstruction * concatenate)450 Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
451   absl::Span<HloInstruction* const> operands(concatenate->operands());
452   // The result concatenate dimension is going to be the sum of all
453   // concatenate dimensions of the operands taking part of the operation.
454   const Shape& reference_shape = operands[0]->shape();
455   CHECK(reference_shape.IsArray());
456   const int64 rank = reference_shape.rank();
457   const int64 concat_dim = concatenate->dimensions()[0];
458   CHECK_GE(concat_dim, 0);
459   CHECK_LT(concat_dim, rank);
460 
461   DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
462                                     reference_shape.dimensions().end());
463 
464   for (int64 i = 1; i < operands.size(); ++i) {
465     const Shape& operand_shape = operands[i]->shape();
466     CHECK(operand_shape.IsArray());
467     // Accumulate the concat dimension from all tensors taking part to the
468     // operation.
469     concat_dimensions[concat_dim] +=
470         ShapeUtil::GetDimension(operand_shape, concat_dim);
471   }
472 
473   auto result_literal = LiteralUtil::CreateFromDimensions(
474       reference_shape.element_type(), concat_dimensions);
475   DimensionVector source_indices(rank, 0);
476   DimensionVector dest_indices(concat_dimensions.size(), 0);
477 
478   for (auto operand : operands) {
479     const Shape& operand_shape = operand->shape();
480     TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
481         GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
482         AsInt64Slice(operand_shape.dimensions())));
483     dest_indices[concat_dim] +=
484         ShapeUtil::GetDimension(operand_shape, concat_dim);
485   }
486 
487   evaluated_[concatenate] = std::move(result_literal);
488   return Status::OK();
489 }
490 
HandleIsFinite(HloInstruction * is_finite)491 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
492   auto operand = is_finite->operand(0);
493   auto elem_ty = operand->shape().element_type();
494   switch (elem_ty) {
495     case PRED:
496     case TUPLE:
497     case OPAQUE:
498     case TOKEN:
499     case S8:
500     case S16:
501     case S32:
502     case S64:
503     case U8:
504     case U16:
505     case U32:
506     case U64:
507     case C64:
508     case C128:
509     // Explicitly enumerate all types in this switch so that when we add a new
510     // type, we'll get a compile error here.
511     case PRIMITIVE_TYPE_INVALID:
512     case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
513     case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
514       return InvalidArgument(
515           "expected element type in shape to be floating point, but "
516           "got: %s",
517           PrimitiveType_Name(elem_ty));
518 
519     case F16: {
520       auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
521           is_finite,
522           [](Eigen::half elem_operand) {
523             return std::isfinite(static_cast<float>(elem_operand));
524           },
525           GetEvaluatedLiteralFor(operand));
526       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
527       break;
528     }
529     case BF16: {
530       auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
531           is_finite,
532           [](bfloat16 elem_operand) {
533             return std::isfinite(static_cast<float>(elem_operand));
534           },
535           GetEvaluatedLiteralFor(operand));
536       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
537       break;
538     }
539     case F32: {
540       auto result_or = ElementWiseUnaryOpImpl<bool, float>(
541           is_finite,
542           [](float elem_operand) { return std::isfinite(elem_operand); },
543           GetEvaluatedLiteralFor(operand));
544       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
545       break;
546     }
547     case F64: {
548       auto result_or = ElementWiseUnaryOpImpl<bool, double>(
549           is_finite,
550           [](double elem_operand) { return std::isfinite(elem_operand); },
551           GetEvaluatedLiteralFor(operand));
552       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
553       break;
554     }
555   }
556 
557   return Status::OK();
558 }
559 
HandleReal(HloInstruction * real)560 Status HloEvaluator::HandleReal(HloInstruction* real) {
561   auto operand = real->operand(0);
562   switch (operand->shape().element_type()) {
563     case BF16: {
564       auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
565           real, [](bfloat16 elem_operand) { return elem_operand; },
566           GetEvaluatedLiteralFor(operand));
567       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
568       break;
569     }
570     case C64: {
571       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
572           real, [](complex64 elem_operand) { return std::real(elem_operand); },
573           GetEvaluatedLiteralFor(operand));
574       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
575       break;
576     }
577     case C128: {
578       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
579           real, [](complex128 elem_operand) { return std::real(elem_operand); },
580           GetEvaluatedLiteralFor(operand));
581       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
582       break;
583     }
584     case F16: {
585       auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
586           real, [](Eigen::half elem_operand) { return elem_operand; },
587           GetEvaluatedLiteralFor(operand));
588       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
589       break;
590     }
591     case F32: {
592       auto result_or = ElementWiseUnaryOpImpl<float, float>(
593           real, [](float elem_operand) { return elem_operand; },
594           GetEvaluatedLiteralFor(operand));
595       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
596       break;
597     }
598     case F64: {
599       auto result_or = ElementWiseUnaryOpImpl<double, double>(
600           real, [](double elem_operand) { return elem_operand; },
601           GetEvaluatedLiteralFor(operand));
602       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
603       break;
604     }
605     default:
606       LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
607                  << PrimitiveType_Name(operand->shape().element_type());
608   }
609 
610   return Status::OK();
611 }
612 
HandleImag(HloInstruction * imag)613 Status HloEvaluator::HandleImag(HloInstruction* imag) {
614   auto operand = imag->operand(0);
615   switch (operand->shape().element_type()) {
616     case C64: {
617       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
618           imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
619           GetEvaluatedLiteralFor(imag->operand(0)));
620 
621       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
622       break;
623     }
624     case C128: {
625       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
626           imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
627           GetEvaluatedLiteralFor(imag->operand(0)));
628 
629       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
630       break;
631     }
632     default:
633       LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
634                  << PrimitiveType_Name(operand->shape().element_type());
635   }
636 
637   return Status::OK();
638 }
639 
HandleComplex(HloInstruction * complex)640 Status HloEvaluator::HandleComplex(HloInstruction* complex) {
641   const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
642   const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
643   TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
644 
645   Literal result(complex->shape());
646   switch (complex->shape().element_type()) {
647     case C64: {
648       TF_RETURN_IF_ERROR(
649           result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
650             return std::complex<float>(real.Get<float>(multi_index),
651                                        imag.Get<float>(multi_index));
652           }));
653       break;
654     }
655     case C128: {
656       TF_RETURN_IF_ERROR(
657           result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
658             return std::complex<float>(real.Get<double>(multi_index),
659                                        imag.Get<double>(multi_index));
660           }));
661       break;
662     }
663     default:
664       LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
665                  << PrimitiveType_Name(complex->shape().element_type());
666   }
667 
668   evaluated_[complex] = std::move(result);
669   return Status::OK();
670 }
671 
HandleCompare(HloInstruction * compare)672 Status HloEvaluator::HandleCompare(HloInstruction* compare) {
673   ComparisonDirection direction = compare->comparison_direction();
674   auto lhs = compare->operand(0);
675   auto rhs = compare->operand(1);
676   DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
677          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
678 
679   TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
680 
681   const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
682   const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
683 
684   // Note here we switch on the operand's type.
685   switch (lhs->shape().element_type()) {
686     case PRED: {
687       TF_ASSIGN_OR_RETURN(
688           evaluated_[compare],
689           Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
690     } break;
691     case U8: {
692       TF_ASSIGN_OR_RETURN(evaluated_[compare],
693                           Compare<uint8>(compare->shape(), direction,
694                                          lhs_literal, rhs_literal));
695     } break;
696     case U16: {
697       TF_ASSIGN_OR_RETURN(evaluated_[compare],
698                           Compare<uint16>(compare->shape(), direction,
699                                           lhs_literal, rhs_literal));
700     } break;
701     case U32: {
702       TF_ASSIGN_OR_RETURN(evaluated_[compare],
703                           Compare<uint32>(compare->shape(), direction,
704                                           lhs_literal, rhs_literal));
705     } break;
706     case U64: {
707       TF_ASSIGN_OR_RETURN(evaluated_[compare],
708                           Compare<uint64>(compare->shape(), direction,
709                                           lhs_literal, rhs_literal));
710     } break;
711     case S8: {
712       TF_ASSIGN_OR_RETURN(
713           evaluated_[compare],
714           Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
715     } break;
716     case S16: {
717       TF_ASSIGN_OR_RETURN(evaluated_[compare],
718                           Compare<int16>(compare->shape(), direction,
719                                          lhs_literal, rhs_literal));
720     } break;
721     case S32: {
722       TF_ASSIGN_OR_RETURN(evaluated_[compare],
723                           Compare<int32>(compare->shape(), direction,
724                                          lhs_literal, rhs_literal));
725     } break;
726     case S64: {
727       TF_ASSIGN_OR_RETURN(evaluated_[compare],
728                           Compare<int64>(compare->shape(), direction,
729                                          lhs_literal, rhs_literal));
730     } break;
731     case F16: {
732       TF_ASSIGN_OR_RETURN(
733           evaluated_[compare],
734           Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
735     } break;
736     case BF16: {
737       TF_ASSIGN_OR_RETURN(evaluated_[compare],
738                           Compare<bfloat16>(compare->shape(), direction,
739                                             lhs_literal, rhs_literal));
740     } break;
741     case F32: {
742       TF_ASSIGN_OR_RETURN(evaluated_[compare],
743                           Compare<float>(compare->shape(), direction,
744                                          lhs_literal, rhs_literal));
745     } break;
746     case F64: {
747       TF_ASSIGN_OR_RETURN(evaluated_[compare],
748                           Compare<double>(compare->shape(), direction,
749                                           lhs_literal, rhs_literal));
750     } break;
751     case C64: {
752       TF_ASSIGN_OR_RETURN(evaluated_[compare],
753                           Compare<complex64>(compare->shape(), direction,
754                                              lhs_literal, rhs_literal));
755     } break;
756     case C128: {
757       TF_ASSIGN_OR_RETURN(evaluated_[compare],
758                           Compare<complex128>(compare->shape(), direction,
759                                               lhs_literal, rhs_literal));
760     } break;
761     default:
762       LOG(FATAL) << "HandleCompare: unknown primitive type: "
763                  << PrimitiveType_Name(lhs->shape().element_type());
764   }
765 
766   return Status::OK();
767 }
768 
HandleTuple(HloInstruction * tuple)769 Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
770   std::vector<const Literal*> operand_literals;
771   for (auto operand : tuple->operands()) {
772     operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
773   }
774 
775   evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
776   return Status::OK();
777 }
778 
779 // Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
780 // dimensions while keeping the rest of the output dimensions clamped to 0.
IterationSpaceForOutputBatchIndices(const Shape & output_shape,const GatherDimensionNumbers & dim_numbers)781 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
782     const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
783   int64 output_rank = output_shape.dimensions_size();
784   std::vector<int64> index_base(output_rank, 0);
785   std::vector<int64> index_count;
786   index_count.reserve(output_rank);
787   for (int64 i = 0; i < output_rank; i++) {
788     bool is_output_batch_dim =
789         !absl::c_binary_search(dim_numbers.offset_dims(), i);
790     index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
791   }
792 
793   return {std::move(index_base), std::move(index_count),
794           std::vector<int64>(output_rank, 1)};
795 }
796 
797 // Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
798 // dimensions while keeping the rest of the output dimensions clamped to 0.
IterationSpaceForOutputOffsetIndices(int64 output_rank,absl::Span<const int64> slice_sizes,const GatherDimensionNumbers & dim_numbers)799 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
800     int64 output_rank, absl::Span<const int64> slice_sizes,
801     const GatherDimensionNumbers& dim_numbers) {
802   std::vector<int64> index_base(output_rank, 0);
803   std::vector<int64> index_count(output_rank, 1);
804   int64 slice_sizes_idx = 0;
805   for (int64 i = 0; i < output_rank; i++) {
806     bool is_output_window_dim =
807         absl::c_binary_search(dim_numbers.offset_dims(), i);
808     if (is_output_window_dim) {
809       while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
810                                    slice_sizes_idx)) {
811         slice_sizes_idx++;
812       }
813       index_count[i] = slice_sizes[slice_sizes_idx++];
814     }
815   }
816 
817   return {std::move(index_base), std::move(index_count),
818           std::vector<int64>(output_rank, 1)};
819 }
820 
821 // This functor computes the contribution of start_indices to an input index
822 // corresponding to an output index.  That is, given an output index I, it picks
823 // out the batch indices in I and uses them to look up a starting index, G, from
824 // the start indices tensor, and expands G into the input space according to
825 // start_index_map.
826 class OutputBatchIndexToInputIndex {
827  public:
828   // The constructor does some setup work that is amortized across all
829   // iterations.
OutputBatchIndexToInputIndex(const GatherDimensionNumbers * dim_numbers,const Shape & input_shape,const Shape & output_shape,const Literal * start_indices)830   explicit OutputBatchIndexToInputIndex(
831       const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
832       const Shape& output_shape, const Literal* start_indices)
833       : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
834     for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
835       output_dim_is_batch_dims_.push_back(
836           !absl::c_binary_search(dim_numbers_.offset_dims(), i));
837     }
838 
839     for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
840       int64 index_of_input_dim_in_index_vector =
841           std::distance(dim_numbers_.start_index_map().begin(),
842                         absl::c_find(dim_numbers_.start_index_map(), i));
843       if (index_of_input_dim_in_index_vector ==
844           dim_numbers_.start_index_map_size()) {
845         input_dim_value_to_index_vector_.push_back(-1);
846       } else {
847         input_dim_value_to_index_vector_.push_back(
848             index_of_input_dim_in_index_vector);
849       }
850     }
851 
852     index_vector_index_.resize(start_indices_.shape().dimensions_size());
853     input_index_.resize(input_shape.dimensions_size());
854     int64 index_vector_size =
855         start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
856     index_vector_.resize(index_vector_size);
857   }
858 
859   // Returns the contribution of start_indices to the input index corresponding
860   // to output_index.  See gather_inner_loop_body.
861   //
862   // This is conceptually  a stateless transformation from output_index to the
863   // gather input index, but:
864   //
865   //  - Instead of allocating memory to represent the gather input index on
866   //    every invocation we reuse the same storage for the result
867   //    (input_index_), mutating it in place.
868   //  - Instead of allocating buffers for temporary values like
869   //    index_vector_index_ and index_vector on every invocation, we reuse the
870   //    same storage for all invocations.
871   //
872   // This returns a Span into memory owned by the class.
operator ()(absl::Span<const int64> output_index)873   StatusOr<absl::Span<const int64>> operator()(
874       absl::Span<const int64> output_index) {
875     PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
876     TF_RETURN_IF_ERROR(FetchIndexVector());
877     PropagateIndexVectorToInputIndex();
878     return absl::Span<const int64>(input_index_);
879   }
880 
881  private:
882   // Propagates the batch dimensions from the output index into
883   // index_vector_index_ by mutating index_vector_index_ in place.  Does not
884   // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
885   // we iterate over in FetchIndexVector.
PropagateOutputIndexGatherDimsToIndexVectorIndex(absl::Span<const int64> output_index)886   void PropagateOutputIndexGatherDimsToIndexVectorIndex(
887       absl::Span<const int64> output_index) {
888     int64 index_vector_index_i = 0;
889     for (int64 i = 0, e = output_index.size(); i < e; i++) {
890       if (!output_dim_is_batch_dims_[i]) {
891         continue;
892       }
893 
894       if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
895         index_vector_index_i++;
896       }
897 
898       index_vector_index_[index_vector_index_i++] = output_index[i];
899     }
900   }
901 
902   // Populates index_vector_ by iterating over start_indices_ according to
903   // index_vector_index_.
FetchIndexVector()904   Status FetchIndexVector() {
905     int64 index_vector_dim = dim_numbers_.index_vector_dim();
906     for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
907       index_vector_index_[index_vector_dim] = i;
908       TF_ASSIGN_OR_RETURN(index_vector_[i],
909                           start_indices_.GetIntegralAsS64(index_vector_index_));
910     }
911     return Status::OK();
912   }
913 
914   // Populates input_index_.
PropagateIndexVectorToInputIndex()915   void PropagateIndexVectorToInputIndex() {
916     for (int64 i = 0, e = input_index_.size(); i < e; i++) {
917       if (input_dim_value_to_index_vector_[i] != -1) {
918         input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
919       }
920 
921       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
922       // remains 0, as set by the constructor.
923     }
924   }
925 
926   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
927   // the input index from the index vector.  See
928   // PropagateIndexVectorToInputIndex.
929   std::vector<int64> input_dim_value_to_index_vector_;
930 
931   // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
932   // dimension.
933   std::vector<bool> output_dim_is_batch_dims_;
934 
935   // The buffer into which we construct an index into start_indices_ to fetch
936   // the index vector.
937   std::vector<int64> index_vector_index_;
938 
939   // The index vector fetched from start_indices_.
940   std::vector<int64> index_vector_;
941 
942   // The result computed by this functor.  operator() returns a Span into
943   // this vector.
944   std::vector<int64> input_index_;
945 
946   const GatherDimensionNumbers& dim_numbers_;
947   const Literal& start_indices_;
948 };
949 
950 // This functor computes the contribution of the offset indices in an output
951 // index to an input index.  That is, given an output index I it picks out the
952 // output offset indices in I and expands it into an index into the input shape.
953 class OutputOffsetIndexToInputIndex {
954  public:
955   // The constructor does some setup work that is amortized across all
956   // iterations.
OutputOffsetIndexToInputIndex(const GatherDimensionNumbers & dim_numbers,const Shape & input_shape,const Shape & output_shape)957   explicit OutputOffsetIndexToInputIndex(
958       const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
959       const Shape& output_shape) {
960     std::vector<int64> window_index_to_output_index;
961     int64 output_index_count = 0;
962     for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
963       if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
964         window_index_to_output_index.push_back(output_index_count++);
965       } else {
966         output_index_count++;
967       }
968     }
969 
970     int64 window_dim_count = 0;
971     for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
972       if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
973         input_dim_value_to_output_index_.push_back(-1);
974       } else {
975         input_dim_value_to_output_index_.push_back(
976             window_index_to_output_index[window_dim_count++]);
977       }
978     }
979 
980     input_index_.resize(input_shape.dimensions_size());
981   }
982 
983   // Returns the contribution of the window indices to the input index
984   // corresponding to output_index.  See gather_inner_loop_body.
985   //
986   // This is conceptually a stateless transformation from output_index to the
987   // window input index, but instead of allocating memory to represent the
988   // gather input index on every invocation we reuse the same storage for the
989   // result (input_index_), mutating it in place.
990   //
991   // This returns a Span into memory owned by the class.
operator ()(absl::Span<const int64> output_index)992   StatusOr<absl::Span<const int64>> operator()(
993       absl::Span<const int64> output_index) {
994     PropagateOutputIndexWindowDimsToInputIndex(output_index);
995     return absl::Span<const int64>(input_index_);
996   }
997 
998   // Returns for a given 'input_dim' the corresponding output dimension index,
999   // or -1 if 'input_dim' is an elided window dimension.
input_dim_value_to_output_index(int64 input_dim)1000   int64 input_dim_value_to_output_index(int64 input_dim) {
1001     return input_dim_value_to_output_index_[input_dim];
1002   }
1003 
1004  private:
1005   // Propagates window dimensions from the output index to input_index_ by
1006   // mutating input_index_ in place.
PropagateOutputIndexWindowDimsToInputIndex(absl::Span<const int64> output_index)1007   void PropagateOutputIndexWindowDimsToInputIndex(
1008       absl::Span<const int64> output_index) {
1009     for (int64 i = 0, e = input_index_.size(); i < e; i++) {
1010       if (input_dim_value_to_output_index_[i] != -1) {
1011         input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
1012       }
1013 
1014       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
1015       // remains 0, as set by the constructor.
1016     }
1017   }
1018 
1019   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
1020   // the input index from the output index. See
1021   // PropagateOutputIndexWindowDimsToInputIndex.
1022   std::vector<int64> input_dim_value_to_output_index_;
1023 
1024   // The result computed by this functor.  operator() returns a Span into
1025   // this vector.
1026   std::vector<int64> input_index_;
1027 };
1028 
1029 // Rehapes the gather indices input to have a trailing degenerate `1` dimension
1030 // if necessary.  Hands over the ownership of the newly created literal (if
1031 // there is one) to `reshaped_start_indices`.
ReshapedGatherIndices(int64 index_vector_dim,const Literal & start_indices,Literal * reshaped_start_indices)1032 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
1033     int64 index_vector_dim, const Literal& start_indices,
1034     Literal* reshaped_start_indices) {
1035   if (start_indices.shape().dimensions_size() != index_vector_dim) {
1036     return std::cref(start_indices);
1037   }
1038 
1039   std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
1040                                start_indices.shape().dimensions().end());
1041   new_shape.push_back(1);
1042   TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
1043                       start_indices.Reshape(new_shape));
1044   return std::cref(*reshaped_start_indices);
1045 }
1046 
HandleGather(HloInstruction * gather)1047 Status HloEvaluator::HandleGather(HloInstruction* gather) {
1048   Literal result = Literal::CreateFromShape(gather->shape());
1049   const Shape& shape = gather->shape();
1050   const GatherDimensionNumbers& dim_numbers =
1051       gather->gather_dimension_numbers();
1052   const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
1053   Literal reshaped_start_indices;
1054   TF_ASSIGN_OR_RETURN(
1055       const Literal& start_indices,
1056       ReshapedGatherIndices(dim_numbers.index_vector_dim(),
1057                             GetEvaluatedLiteralFor(gather->operand(1)),
1058                             &reshaped_start_indices));
1059 
1060   // We iterate over the gather dimensions in the output shape in an outer loop
1061   // nest, and iterate over the window dimensions in the output shape in an
1062   // inner loop nest.
1063 
1064   ShapeUtil::IndexIterationSpace start_indices_iteration_space =
1065       IterationSpaceForOutputBatchIndices(shape, dim_numbers);
1066   ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
1067       IterationSpaceForOutputOffsetIndices(
1068           shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
1069 
1070   // Scratch buffers that hold an index in the output shape and the
1071   // corresponding index in the input shape.
1072   std::vector<int64> input_index(operand.shape().dimensions_size());
1073   std::vector<int64> output_index(gather->shape().dimensions_size());
1074   std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
1075 
1076   OutputBatchIndexToInputIndex output_batch_index_to_input_index(
1077       &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1078       /*output_shape=*/shape, &start_indices);
1079   OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
1080       gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1081       /*output_shape=*/shape);
1082 
1083   const Shape& operand_shape = operand.shape();
1084 
1085   auto gather_inner_loop_body =
1086       [&](absl::Span<const int64> output_window_index,
1087           absl::Span<const int64> input_gather_index,
1088           absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1089     TF_ASSIGN_OR_RETURN(
1090         absl::Span<const int64> input_window_index,
1091         output_offset_index_to_input_index(output_window_index));
1092     for (int i = 0, e = output_index.size(); i < e; i++) {
1093       output_index[i] = output_gather_index[i] + output_window_index[i];
1094       DCHECK_LT(output_index[i], shape.dimensions(i));
1095     }
1096     for (int i = 0, e = input_gather_index.size(); i < e; i++) {
1097       int64 output_dim =
1098           output_offset_index_to_input_index.input_dim_value_to_output_index(i);
1099       // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
1100       // we set the iteration index to 0, so for the purpose of the following
1101       // calculations we can consider the output dimension size to be 1.
1102       int64 output_dim_size =
1103           output_dim == -1 ? 1 : shape.dimensions(output_dim);
1104       // Clamp the gather index so that the gather region fits in the operand.
1105       // input_index_clamped[i] = clamp(input_gather_index[i], 0,
1106       //                                       operand_shape.dimensions(i) -
1107       //                                       output_dim_size);
1108       input_index_clamped[i] =
1109           std::min(operand_shape.dimensions(i) - output_dim_size,
1110                    std::max(0LL, input_gather_index[i]));
1111     }
1112     for (int i = 0, e = input_index.size(); i < e; i++) {
1113       input_index[i] = input_index_clamped[i] + input_window_index[i];
1114       DCHECK_GE(input_index[i], 0);
1115       DCHECK_LT(input_index[i], operand_shape.dimensions(i));
1116     }
1117     TF_RETURN_IF_ERROR(
1118         result.CopyElementFrom(operand, input_index, output_index));
1119     return true;
1120   };
1121 
1122   auto gather_outer_loop_body =
1123       [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1124     TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
1125                         output_batch_index_to_input_index(output_gather_index));
1126     TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1127         shape, offset_indices_iteration_space,
1128         std::bind(gather_inner_loop_body, std::placeholders::_1,
1129                   input_gather_index, output_gather_index)));
1130     return true;
1131   };
1132 
1133   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1134       shape, start_indices_iteration_space, gather_outer_loop_body));
1135   evaluated_[gather] = std::move(result);
1136   return Status::OK();
1137 }
1138 
HandleBroadcast(HloInstruction * broadcast)1139 Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
1140   const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
1141 
1142   TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
1143       << "broadcast dimensions is of size: " << broadcast->dimensions().size()
1144       << " and rank of operand_to_broadcast is: " << operand.shape().rank();
1145   // Checks that operand's dimensions are the same as the broadcast's
1146   // dimensions along the dimensions to be broadcasted.
1147   for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
1148     auto operand_dim_size = operand.shape().dimensions(i);
1149     auto broadcast_dim_size =
1150         broadcast->shape().dimensions(broadcast->dimensions(i));
1151     TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
1152         "Operand dimension %d is broadcast to output dimension %d, but the "
1153         "sizes of these two dims do not match (%d vs %d): %s",
1154         i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
1155         broadcast->ToString());
1156   }
1157 
1158   TF_ASSIGN_OR_RETURN(
1159       evaluated_[broadcast],
1160       operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
1161 
1162   return Status::OK();
1163 }
1164 
HandleAfterAll(HloInstruction * after_all)1165 Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
1166   evaluated_[after_all] = LiteralUtil::CreateToken();
1167   return Status::OK();
1168 }
1169 
HandleAddDependency(HloInstruction * add_dependency)1170 Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
1171   // AddDedendency just forwards its zero-th operand.
1172   evaluated_[add_dependency] =
1173       GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
1174   return Status::OK();
1175 }
1176 
HandleGetTupleElement(HloInstruction * get_tuple_element)1177 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
1178   const auto result_shape = get_tuple_element->shape();
1179   const int64 index = get_tuple_element->tuple_index();
1180 
1181   auto operand = get_tuple_element->operand(0);
1182   TF_ASSIGN_OR_RETURN(
1183       auto inferred_return_shape,
1184       ShapeInference::InferGetTupleElementShape(operand->shape(), index));
1185   TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
1186       << "return shape set to: " << ShapeUtil::HumanString(result_shape)
1187       << " but is inferred to be: "
1188       << ShapeUtil::HumanString(inferred_return_shape);
1189 
1190   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1191 
1192   evaluated_[get_tuple_element] =
1193       Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
1194   return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
1195                                                 /*dest_shape_index=*/{},
1196                                                 /*src_shape_index=*/{index});
1197 }
1198 
HandleCopy(HloInstruction * copy)1199 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
1200   TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
1201   evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
1202   return Status::OK();
1203 }
1204 
HandleCall(HloInstruction * call)1205 Status HloEvaluator::HandleCall(HloInstruction* call) {
1206   auto* computation = call->to_apply();
1207   auto operands = call->operands();
1208 
1209   std::vector<const Literal*> arg_literals;
1210   arg_literals.reserve(operands.size());
1211   for (auto operand : operands) {
1212     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1213     arg_literals.push_back(&arg_literal);
1214   }
1215 
1216   HloEvaluator embedded_evaluator;
1217   embedded_evaluator.set_dynamic_dimension_inference(
1218       dynamic_dimension_inference_);
1219   TF_ASSIGN_OR_RETURN(Literal result,
1220                       embedded_evaluator.Evaluate(*computation, arg_literals));
1221 
1222   evaluated_[call] = std::move(result);
1223   return Status::OK();
1224 }
1225 
HandleFusion(HloInstruction * fusion)1226 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
1227   HloModuleConfig config;
1228   // Attach cloned computation to an empty HLO module so the existing ones are
1229   // not modified.
1230   HloModule empty_hlo_module("EmptyModuleForFusion", config);
1231   HloCloneContext context(&empty_hlo_module);
1232   auto cloned_fused_computation =
1233       fusion->fused_instructions_computation()->Clone(
1234           /*suffix=*/"clone_with_layout", &context);
1235   for (auto* instruction : cloned_fused_computation->instructions()) {
1236     if (!LayoutUtil::HasLayout(instruction->shape())) {
1237       LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
1238     }
1239   }
1240   auto readded_computation =
1241       empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
1242 
1243   auto operands = fusion->operands();
1244   std::vector<const Literal*> arg_literals;
1245   arg_literals.reserve(operands.size());
1246   for (auto operand : operands) {
1247     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1248     arg_literals.push_back(&arg_literal);
1249   }
1250 
1251   HloEvaluator embedded_evaluator;
1252   embedded_evaluator.set_dynamic_dimension_inference(
1253       dynamic_dimension_inference_);
1254   TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
1255                                           *readded_computation, arg_literals));
1256 
1257   evaluated_[fusion] = std::move(result);
1258   return Status::OK();
1259 }
1260 
HandleConditional(HloInstruction * conditional)1261 Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
1262   const auto& branch_index_literal =
1263       GetEvaluatedLiteralFor(conditional->operand(0));
1264   int branch_index;
1265   if (conditional->operand(0)->shape().element_type() == PRED) {
1266     branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
1267   } else {
1268     branch_index = branch_index_literal.Get<int32>({});
1269     if (branch_index < 0 || branch_index >= conditional->branch_count()) {
1270       branch_index = conditional->branch_count() - 1;
1271     }
1272   }
1273   const auto& branch_computation_arg =
1274       GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
1275 
1276   HloEvaluator embedded_evaluator;
1277   embedded_evaluator.set_dynamic_dimension_inference(
1278       dynamic_dimension_inference_);
1279   TF_ASSIGN_OR_RETURN(Literal result,
1280                       embedded_evaluator.Evaluate(
1281                           *conditional->branch_computation(branch_index),
1282                           {&branch_computation_arg}));
1283 
1284   evaluated_[conditional] = std::move(result);
1285   return Status::OK();
1286 }
1287 
HandleSelect(HloInstruction * select)1288 Status HloEvaluator::HandleSelect(HloInstruction* select) {
1289   const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
1290   const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
1291   const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
1292 
1293   // If predicate is of scalar type, no element-wise selection would be needed.
1294   if (ShapeUtil::IsScalar(pred.shape())) {
1295     if (pred.Get<bool>({})) {
1296       evaluated_[select] = on_true.Clone();
1297     } else {
1298       evaluated_[select] = on_false.Clone();
1299     }
1300     return Status::OK();
1301   }
1302 
1303   return DefaultAction(select);
1304 }
1305 
HandleTupleSelect(HloInstruction * tuple_select)1306 Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
1307   const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
1308   const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
1309   const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
1310 
1311   if (pred.Get<bool>({})) {
1312     evaluated_[tuple_select] = on_true.Clone();
1313   } else {
1314     evaluated_[tuple_select] = on_false.Clone();
1315   }
1316   return Status::OK();
1317 }
1318 
HandleWhile(HloInstruction * while_hlo)1319 Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
1320   HloComputation* cond_comp = while_hlo->while_condition();
1321   HloComputation* body_comp = while_hlo->while_body();
1322   // Initialize the loop carried valued with the input to the While instruction.
1323   auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
1324   bool keep_going = true;
1325   int64 iteration_count = 0;
1326   HloEvaluator cond_evaluator(max_loop_iterations_);
1327   cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
1328   HloEvaluator loop_body_evaluator(max_loop_iterations_);
1329   loop_body_evaluator.set_dynamic_dimension_inference(
1330       dynamic_dimension_inference_);
1331   while (keep_going) {
1332     if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
1333       return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
1334                              while_hlo->name(), max_loop_iterations_);
1335     }
1336     TF_ASSIGN_OR_RETURN(auto cond_val,
1337                         cond_evaluator.Evaluate(*cond_comp, {&lcv}));
1338     keep_going = cond_val.GetFirstElement<bool>();
1339     if (keep_going) {
1340       TF_ASSIGN_OR_RETURN(auto body_val,
1341                           loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
1342       VLOG(3) << "Loop iteration result: " << body_val.ToString();
1343       lcv = std::move(body_val);
1344       cond_evaluator.ResetVisitStates();
1345       loop_body_evaluator.ResetVisitStates();
1346     }
1347   }
1348   evaluated_[while_hlo] = std::move(lcv);
1349   return Status::OK();
1350 }
1351 
1352 namespace {
1353 template <typename NativeT>
ExtractLiteralFromIndexPositions(const Literal & from,absl::Span<int64 const> indices,bool extract_as_scalar)1354 Literal ExtractLiteralFromIndexPositions(const Literal& from,
1355                                          absl::Span<int64 const> indices,
1356                                          bool extract_as_scalar) {
1357   if (extract_as_scalar) {
1358     return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
1359   }
1360   // We use a InlinedVector here because we need to convert it to an
1361   // absl::Span later, and this would not work with std::vector<bool>.
1362   absl::InlinedVector<NativeT, 10> values;
1363   for (int64 index : indices) {
1364     values.push_back(from.Get<NativeT>({index}));
1365   }
1366   return LiteralUtil::CreateR1<NativeT>(values);
1367 }
1368 
ExtractFromIndexPositions(const Literal & from,absl::Span<int64 const> indices,bool extract_as_scalar=false)1369 StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
1370                                             absl::Span<int64 const> indices,
1371                                             bool extract_as_scalar = false) {
1372   if (extract_as_scalar) {
1373     CHECK_EQ(indices.size(), 1);
1374   }
1375   PrimitiveType type = from.shape().element_type();
1376   switch (type) {
1377     case PRED: {
1378       return ExtractLiteralFromIndexPositions<bool>(from, indices,
1379                                                     extract_as_scalar);
1380     }
1381     case U8: {
1382       return ExtractLiteralFromIndexPositions<uint8>(from, indices,
1383                                                      extract_as_scalar);
1384     }
1385     case S8: {
1386       return ExtractLiteralFromIndexPositions<int8>(from, indices,
1387                                                     extract_as_scalar);
1388     }
1389     case BF16: {
1390       return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
1391                                                         extract_as_scalar);
1392     }
1393     case F16: {
1394       return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
1395                                                            extract_as_scalar);
1396     }
1397     case U16: {
1398       return ExtractLiteralFromIndexPositions<uint16>(from, indices,
1399                                                       extract_as_scalar);
1400     }
1401     case S16: {
1402       return ExtractLiteralFromIndexPositions<int16>(from, indices,
1403                                                      extract_as_scalar);
1404     }
1405     case F32: {
1406       return ExtractLiteralFromIndexPositions<float>(from, indices,
1407                                                      extract_as_scalar);
1408     }
1409     case U32: {
1410       return ExtractLiteralFromIndexPositions<uint32>(from, indices,
1411                                                       extract_as_scalar);
1412     }
1413     case S32: {
1414       return ExtractLiteralFromIndexPositions<int32>(from, indices,
1415                                                      extract_as_scalar);
1416     }
1417     case F64: {
1418       return ExtractLiteralFromIndexPositions<double>(from, indices,
1419                                                       extract_as_scalar);
1420     }
1421     case U64: {
1422       return ExtractLiteralFromIndexPositions<uint64>(from, indices,
1423                                                       extract_as_scalar);
1424     }
1425     case S64: {
1426       return ExtractLiteralFromIndexPositions<int64>(from, indices,
1427                                                      extract_as_scalar);
1428     }
1429     default:
1430       return InvalidArgument("Unsupported type for Sort: %s",
1431                              PrimitiveType_Name(type));
1432   }
1433 }
1434 }  // namespace
1435 
HandleSort(HloInstruction * sort)1436 Status HloEvaluator::HandleSort(HloInstruction* sort) {
1437   TF_RET_CHECK(sort->operand_count() >= 1)
1438       << "Expected at least 1 operand for sort";
1439   for (int64 i = 1; i < sort->operand_count(); ++i) {
1440     TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
1441                                            sort->operand(i)->shape()))
1442         << "All Sort operands must have the same dimensions";
1443   }
1444 
1445   if (VLOG_IS_ON(3)) {
1446     for (int64 i = 0; i < sort->operand_count(); ++i) {
1447       VLOG(3) << "HandleSort operand " << i << " literal: "
1448               << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
1449     }
1450   }
1451   Shape key_shape = sort->operand(0)->shape();
1452   auto rank = key_shape.rank();
1453   std::vector<Literal> result_literals;
1454   result_literals.reserve(sort->operand_count());
1455   for (int64 i = 0; i < sort->operand_count(); ++i) {
1456     result_literals.emplace_back(sort->operand(i)->shape());
1457   }
1458   std::vector<int64> zero_base(rank, 0);
1459   std::vector<int64> increment(rank, 1);
1460   int64 sort_dim = sort->dimensions(0);
1461   int64 sort_dim_elements = key_shape.dimensions(sort_dim);
1462   increment[sort_dim] = sort_dim_elements;
1463   HloEvaluator embedded_evaluator(max_loop_iterations_);
1464   // Iterate through each dimension except 'sort_dim'.
1465   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1466       key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
1467       [&](absl::Span<const int64> indices) -> StatusOr<bool> {
1468         // Extract a slice from each operand literal that corresponds to
1469         // exactly the row in dimension 'sort_dim'.
1470         std::vector<int64> limit_indices(indices.begin(), indices.end());
1471         absl::c_for_each(limit_indices, [](int64& index) { ++index; });
1472         limit_indices[sort_dim] = sort_dim_elements;
1473         std::vector<Literal> literals_to_sort;
1474         literals_to_sort.reserve(sort->operand_count());
1475         for (int64 i = 0; i < sort->operand_count(); ++i) {
1476           TF_ASSIGN_OR_RETURN(auto literal_to_sort,
1477                               GetEvaluatedLiteralFor(sort->operand(i))
1478                                   .Slice(indices, limit_indices)
1479                                   .Reshape({sort_dim_elements}));
1480           literals_to_sort.push_back(std::move(literal_to_sort));
1481         }
1482         std::vector<int64> indices_to_sort(sort_dim_elements);
1483         std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
1484         Status compare_status = Status::OK();
1485         auto comparator = [sort, &compare_status, &embedded_evaluator,
1486                            &literals_to_sort](int64 a, int64 b) {
1487           std::vector<Literal> literals;
1488           literals.reserve(2 * sort->operand_count());
1489           for (int64 i = 0; i < sort->operand_count(); ++i) {
1490             auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
1491                                                  /*extract_as_scalar=*/true);
1492             if (!lhs.ok()) {
1493               compare_status = lhs.status();
1494               return false;
1495             }
1496             literals.push_back(std::move(lhs.ValueOrDie()));
1497             auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
1498                                                  /*extract_as_scalar=*/true);
1499             if (!rhs.ok()) {
1500               compare_status = rhs.status();
1501               return false;
1502             }
1503             literals.push_back(std::move(rhs.ValueOrDie()));
1504           }
1505           std::vector<const Literal*> literal_ptrs;
1506           absl::c_transform(literals, std::back_inserter(literal_ptrs),
1507                             [](const Literal& literal) { return &literal; });
1508 
1509           auto computed_result =
1510               embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
1511           // Clear visit states so that we can use the evaluator again
1512           // on the same computation.
1513           embedded_evaluator.ResetVisitStates();
1514           if (!computed_result.ok()) {
1515             compare_status = computed_result.status();
1516             return false;
1517           }
1518           return computed_result.ValueOrDie().Get<bool>({});
1519         };
1520         if (Cast<HloSortInstruction>(sort)->is_stable()) {
1521           std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
1522                            comparator);
1523         } else {
1524           std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
1525         }
1526         if (!compare_status.ok()) {
1527           return compare_status;
1528         }
1529         std::vector<int64> slice_dimensions(rank, 1);
1530         slice_dimensions[sort_dim] = sort_dim_elements;
1531         std::vector<int64> start_indices(rank, 0);
1532         for (int64 i = 0; i < sort->operand_count(); ++i) {
1533           TF_ASSIGN_OR_RETURN(
1534               Literal sorted_literal,
1535               ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
1536           TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
1537                               sorted_literal.Reshape(slice_dimensions));
1538           TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
1539               sorted_literal_reshaped, start_indices, indices,
1540               slice_dimensions));
1541         }
1542         return true;
1543       }));
1544 
1545   if (sort->operand_count() == 1) {
1546     evaluated_[sort] = std::move(result_literals[0]);
1547   } else {
1548     std::vector<const Literal*> literal_ptrs;
1549     absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
1550                       [](const Literal& literal) { return &literal; });
1551 
1552     Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
1553     VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
1554 
1555     evaluated_[sort] = std::move(result_tuple);
1556   }
1557   return Status::OK();
1558 }
1559 
HandleReduce(HloInstruction * reduce)1560 Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
1561   if (!reduce->shape().IsTuple()) {
1562     return DefaultAction(reduce);
1563   } else {
1564     auto first_element_type = reduce->shape().tuple_shapes(0).element_type();
1565     for (const auto& tuple_shape : reduce->shape().tuple_shapes()) {
1566       if (tuple_shape.element_type() != first_element_type) {
1567         return Unimplemented(
1568             "Reduce with several outputs that have mixed element types is "
1569             "unsupported");
1570       }
1571     }
1572     return reduce->Visit(typed_visitors_[first_element_type].get());
1573   }
1574 }
1575 
HandleCustomCall(HloInstruction * custom_call)1576 Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
1577   if (!custom_call_handler_) {
1578     // No handler is registered; this means custom-calls are not allowed.
1579     return DefaultAction(custom_call);
1580   }
1581 
1582   // Evaluate input operands so the handler has access to the operand data.
1583   std::vector<const Literal*> operands;
1584   operands.reserve(custom_call->operand_count());
1585   for (const HloInstruction* operand : custom_call->operands()) {
1586     operands.push_back(&GetEvaluatedLiteralFor(operand));
1587   }
1588 
1589   // Synchronously issue the handler to populate the instruction output literal.
1590   TF_ASSIGN_OR_RETURN(
1591       auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));
1592 
1593   evaluated_[custom_call] = std::move(output);
1594   return Status::OK();
1595 }
1596 
Preprocess(HloInstruction * hlo)1597 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
1598   VLOG(2) << "About to visit HLO: " << hlo->ToString();
1599   return ShapeUtil::ValidateShape(hlo->shape());
1600 }
1601 
Postprocess(HloInstruction * hlo)1602 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
1603   VLOG(2) << "Finished visiting " << hlo->ToString()
1604           << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
1605   // Out of convenience the literal may have been produced with a different
1606   // layout. Relayout as indicated by the HLO instruction.
1607   if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
1608                                         hlo->shape())) {
1609     evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
1610   }
1611   return Status::OK();
1612 }
1613 
1614 namespace {
1615 template <typename T>
MatmulArray2DImpl(const Array2D<T> & lhs,const Array2D<T> & rhs,const std::function<void (const void * run_options_ptr,T * out,T * lhs,T * rhs,int64 m,int64 n,int64 k,int32 transpose_lhs,int32 transpose_rhs)> & impl_fn)1616 std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
1617     const Array2D<T>& lhs, const Array2D<T>& rhs,
1618     const std::function<void(
1619         const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n,
1620         int64 k, int32 transpose_lhs, int32 transpose_rhs)>& impl_fn) {
1621   CHECK_EQ(lhs.width(), rhs.height());
1622   int m = lhs.height();
1623   int n = rhs.width();
1624   int k = lhs.width();
1625   auto result = absl::make_unique<Array2D<T>>(m, n);
1626   // Because Eigen is a header-oriented library, make sure that the Eigen code
1627   // is the same as the code used by the CPU backend (otherwise the linker will
1628   // randomly pick *some* definition).
1629   impl_fn(
1630       /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
1631       k,
1632       /*transpose_lhs=*/0,
1633       /*transpose_rhs=*/0);
1634   return result;
1635 }
1636 }  // namespace
1637 
MatmulArray2D(const Array2D<Eigen::half> & lhs,const Array2D<Eigen::half> & rhs)1638 std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
1639     const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
1640   return MatmulArray2DImpl<Eigen::half>(
1641       lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
1642 }
1643 
MatmulArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs)1644 std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
1645     const Array2D<float>& lhs, const Array2D<float>& rhs) {
1646   return MatmulArray2DImpl<float>(
1647       lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
1648 }
1649 
MatmulArray2D(const Array2D<double> & lhs,const Array2D<double> & rhs)1650 std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
1651     const Array2D<double>& lhs, const Array2D<double>& rhs) {
1652   return MatmulArray2DImpl<double>(
1653       lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
1654 }
1655 
1656 }  // namespace xla
1657