1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <complex>
20 #include <cstdlib>
21 #include <functional>
22 #include <iterator>
23 #include <string>
24 #include <type_traits>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/string_view.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/index_util.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/map_util.h"
36 #include "tensorflow/compiler/xla/primitive_util.h"
37 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
38 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
40 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
41 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
42 #include "tensorflow/compiler/xla/service/hlo_query.h"
43 #include "tensorflow/compiler/xla/service/shape_inference.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/statusor.h"
46 #include "tensorflow/compiler/xla/types.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/compiler/xla/window_util.h"
49 #include "tensorflow/core/lib/core/bitmap.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/status.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/types.h"
55 
56 namespace xla {
57 
58 namespace {
59 
60 template <typename OperandT>
61 StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
62                           LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
63   std::function<bool(OperandT, OperandT)> compare_op;
64   switch (direction) {
65     case ComparisonDirection::kEq:
66       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
67         return lhs_el == rhs_el;
68       };
69       break;
70     case ComparisonDirection::kNe:
71       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
72         return lhs_el != rhs_el;
73       };
74       break;
75     case ComparisonDirection::kGe:
76       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
77         return lhs_el >= rhs_el;
78       };
79       break;
80     case ComparisonDirection::kGt:
81       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
82         return lhs_el > rhs_el;
83       };
84       break;
85     case ComparisonDirection::kLe:
86       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
87         return lhs_el <= rhs_el;
88       };
89       break;
90     case ComparisonDirection::kLt:
91       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
92         return lhs_el < rhs_el;
93       };
94       break;
95   }
96 
97   Literal result(shape);
98   TF_RETURN_IF_ERROR(
99       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
100         return compare_op(lhs_literal.Get<OperandT>(multi_index),
101                           rhs_literal.Get<OperandT>(multi_index));
102       }));
103 
104   return std::move(result);
105 }
106 
107 template <>
108 StatusOr<Literal> Compare<complex64>(const Shape& shape,
109                                      ComparisonDirection direction,
110                                      LiteralSlice lhs_literal,
111                                      LiteralSlice rhs_literal) {
112   std::function<bool(complex64, complex64)> compare_op;
113   switch (direction) {
114     case ComparisonDirection::kEq:
115       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
116         return lhs_el == rhs_el;
117       };
118       break;
119     case ComparisonDirection::kNe:
120       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
121         return lhs_el != rhs_el;
122       };
123       break;
124     default:
125       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
126                  << ComparisonDirectionToString(direction);
127   }
128 
129   Literal result(shape);
130   TF_RETURN_IF_ERROR(
131       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
132         return compare_op(lhs_literal.Get<complex64>(multi_index),
133                           rhs_literal.Get<complex64>(multi_index));
134       }));
135 
136   return std::move(result);
137 }
138 
139 template <>
140 StatusOr<Literal> Compare<complex128>(const Shape& shape,
141                                       ComparisonDirection direction,
142                                       LiteralSlice lhs_literal,
143                                       LiteralSlice rhs_literal) {
144   std::function<bool(complex128, complex128)> compare_op;
145   switch (direction) {
146     case ComparisonDirection::kEq:
147       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
148         return lhs_el == rhs_el;
149       };
150       break;
151     case ComparisonDirection::kNe:
152       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
153         return lhs_el != rhs_el;
154       };
155       break;
156     default:
157       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
158                  << ComparisonDirectionToString(direction);
159   }
160 
161   Literal result(shape);
162   TF_RETURN_IF_ERROR(
163       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
164         return compare_op(lhs_literal.Get<complex128>(multi_index),
165                           rhs_literal.Get<complex128>(multi_index));
166       }));
167 
168   return std::move(result);
169 }
170 
171 }  // namespace
172 
173 // Note that a type unsupported by the typed visitor is not necessarily
174 // unsupported by the non-typed HloEvaluator (parent evaluator); it may still be
175 // handled by a type-agnostic handler. For example, HandleGetTupleElement in the
176 // parent type-agnostic evaluator accepts the TUPLE primitive type, whereas
177 // HloEvaluatorTypedVisitor cannot.
178 HloEvaluator::HloEvaluator(int64 max_loop_iterations)
179     : max_loop_iterations_(max_loop_iterations) {
180   typed_visitors_[PRED] =
181       absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
182   typed_visitors_[U8] =
183       absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
184   typed_visitors_[U16] =
185       absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
186   typed_visitors_[U32] =
187       absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
188   typed_visitors_[U64] =
189       absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
190   typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
191   typed_visitors_[S16] =
192       absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
193   typed_visitors_[S32] =
194       absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
195   typed_visitors_[S64] =
196       absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
197   typed_visitors_[F16] =
198       absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
199   typed_visitors_[F32] =
200       absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
201   typed_visitors_[F64] =
202       absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
203   typed_visitors_[C64] =
204       absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
205   typed_visitors_[C128] =
206       absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
207 
208   // Most of the evaluator computations we use don't support BF16 (e.g.,
209   // std::ceil, std::tanh). To make evaluator work with BF16, we set all
210   // elementwise computations to be done in F32 and do BF16<->F32 conversion
211   // around the input and the output of the computations.
212   typed_visitors_[BF16] =
213       absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
214 
215   typed_visitors_[TUPLE] =
216       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
217         return Unimplemented(
218             "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
219       });
220   typed_visitors_[OPAQUE_TYPE] =
221       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
222         return Unimplemented(
223             "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE.");
224       });
225   typed_visitors_[TOKEN] =
226       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
227         return Unimplemented(
228             "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
229       });
230 }
231 
232 StatusOr<Literal> HloEvaluator::Evaluate(
233     const HloComputation& computation,
234     absl::Span<const Literal* const> arg_literals) {
235   CHECK(computation.parent() != nullptr);
236   XLA_VLOG_LINES(
237       2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
238 
239   if (arg_literals.size() != computation.num_parameters()) {
240     return InvalidArgument(
241         "Expected %d argument%s, but got %d.", computation.num_parameters(),
242         computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
243   }
244   for (int64 i = 0; i < arg_literals.size(); ++i) {
245     const auto& computation_shape =
246         computation.parameter_instruction(i)->shape();
247     const auto& arg_shape = arg_literals[i]->shape();
248     if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
249                                                    arg_shape)) {
250       return InvalidArgument(
251           "Shape mismatch at parameter %d. Computation expected %s, but arg "
252           "was %s.",
253           i, ShapeUtil::HumanStringWithLayout(computation_shape),
254           ShapeUtil::HumanStringWithLayout(arg_shape));
255     }
256   }
257 
258   evaluated_.clear();
259   arg_literals_.clear();
260   for (const auto& literal_ptr : arg_literals) {
261     arg_literals_.push_back(&*literal_ptr);
262   }
263 
264   // Re-seed RNG, either from the configuration's seed or a monotonic
265   // per-evaluator seed (which prevents two evaluators from returning the same
266   // random sequence).
267   if (computation.parent()->config().seed()) {
268     seed_ = computation.parent()->config().seed();
269   } else {
270     // Start global_seed at a (true) random value.
271     static std::atomic<uint64> global_seed{std::random_device()()};
272     seed_ = global_seed.fetch_add(1);
273   }
274   engine_.seed(seed_);
275 
276   TF_RETURN_IF_ERROR(computation.Accept(this));
277 
278   if (VLOG_IS_ON(100)) {
279     for (const HloInstruction* instr : computation.instructions()) {
280       VLOG(100) << instr->name() << " = " << GetEvaluatedLiteralFor(instr);
281     }
282   }
283 
284   return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
285 }
286 
287 StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
288   // If the instruction is a kCopyDone, simply find the argument that it is
289   // copied from.
290   while (instruction->opcode() == HloOpcode::kCopyDone) {
291     if (instruction->operand(0)->opcode() != HloOpcode::kCopyStart) {
292       return tensorflow::errors::FailedPrecondition(
293           "kCopyDone has an argument different than a kCopyStart.");
294     }
295     instruction = instruction->mutable_operand(0)->mutable_operand(0);
296   }
297   if (instruction->opcode() == HloOpcode::kParameter) {
298     return tensorflow::errors::FailedPrecondition(
299         "Cannot evaluate a parameter.");
300   }
301   if (!hlo_query::AllOperandsAreConstants(*instruction)) {
302     return tensorflow::errors::FailedPrecondition(
303         "Not all operands are constants.");
304   }
305 
306   arg_literals_.clear();
307   evaluated_.clear();
308 
309   TF_RETURN_IF_ERROR(Preprocess(instruction));
310   TF_RETURN_IF_ERROR(instruction->Visit(this));
311   TF_RETURN_IF_ERROR(Postprocess(instruction));
312   return GetEvaluatedLiteralFor(instruction).Clone();
313 }
314 
315 bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
316   CHECK(result != nullptr);
317   auto result_or = Evaluate(instruction);
318   if (!result_or.ok()) {
319     VLOG(1) << "TryEvaluate failed:" << result_or.status();
320     return false;
321   }
322 
323   *result = result_or.ConsumeValueOrDie();
324   return true;
325 }
326 
327 StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
328     const HloInstruction* instruction,
329     const std::unordered_map<const HloInstruction*, const Literal*>&
330         substitutions) {
331   std::vector<std::unique_ptr<HloInstruction>> owned_operands;
332   for (const HloInstruction* operand : instruction->operands()) {
333     auto it = substitutions.find(operand);
334     if (it == substitutions.end()) {
335       owned_operands.push_back(operand->Clone());
336     } else {
337       owned_operands.push_back(
338           HloInstruction::CreateConstant(it->second->Clone()));
339     }
340   }
341 
342   std::vector<HloInstruction*> operands;
343   operands.reserve(owned_operands.size());
344   for (auto& operand : owned_operands) {
345     operands.push_back(operand.get());
346   }
347 
348   std::unique_ptr<HloInstruction> cloned_instruction =
349       instruction->CloneWithNewOperands(instruction->shape(), operands);
350   auto result = Evaluate(cloned_instruction.get());
351 
352   return result;
353 }
354 
355 StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
356     HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
357   std::unique_ptr<HloInstruction> lhs_instr =
358       HloInstruction::CreateConstant(lhs.Clone());
359   std::unique_ptr<HloInstruction> rhs_instr =
360       HloInstruction::CreateConstant(rhs.Clone());
361 
362   std::unique_ptr<HloInstruction> cloned_instruction =
363       HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
364                                    rhs_instr.get());
365   auto result = Evaluate(cloned_instruction.get());
366 
367   return result;
368 }
369 
370 StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
371     HloOpcode opcode, const Literal& operand) {
372   std::unique_ptr<HloInstruction> operand_instr =
373       HloInstruction::CreateConstant(operand.Clone());
374 
375   std::unique_ptr<HloInstruction> cloned_instruction =
376       HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
377   auto result = Evaluate(cloned_instruction.get());
378 
379   return result;
380 }
381 
382 StatusOr<Literal> HloEvaluator::EvaluateDotOp(
383     const DotDimensionNumbers& dim_numbers,
384     const PrecisionConfig& precision_config, const Literal& lhs,
385     const Literal& rhs) {
386   std::unique_ptr<HloInstruction> lhs_instr =
387       HloInstruction::CreateConstant(lhs.Clone());
388   std::unique_ptr<HloInstruction> rhs_instr =
389       HloInstruction::CreateConstant(rhs.Clone());
390 
391   TF_ASSIGN_OR_RETURN(Shape dot_shape,
392                       ShapeInference::InferDotOpShape(
393                           lhs.shape(), rhs.shape(), dim_numbers,
394                           /*preferred_element_type=*/absl::nullopt));
395 
396   std::unique_ptr<HloInstruction> cloned_instruction =
397       HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
398                                 dim_numbers, precision_config);
399   return Evaluate(cloned_instruction.get());
400 }
401 
402 Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
403   const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
404   Literal result(bitcast->shape());
405   TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
406   memcpy(result.untyped_data(), operand_literal.untyped_data(),
407          operand_literal.size_bytes());
408   evaluated_[bitcast] = std::move(result);
409   return Status::OK();
410 }
411 
412 Status HloEvaluator::HandleGetDimensionSize(
413     HloInstruction* get_dimension_size) {
414   HloInstruction* operand = get_dimension_size->mutable_operand(0);
415   int64 dim = get_dimension_size->dimension();
416   if (dynamic_dimension_inference_ == nullptr) {
417     return InvalidArgument(
418         "Evaluator cannot evaluate get_dimension_size without "
419         "set_dynamic_dimension_inference.");
420   }
421   HloInstruction* dynamic_size =
422       dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
423   if (dynamic_size != nullptr) {
424     evaluated_[get_dimension_size] =
425         GetEvaluatedLiteralFor(dynamic_size).Clone();
426     return Status::OK();
427   }
428 
429   const Shape& shape = get_dimension_size->operand(0)->shape();
430   Literal output(ShapeUtil::MakeShape(S32, {}));
431   output.PopulateWithValue(
432       static_cast<int32>(shape.dimensions(get_dimension_size->dimension())));
433   evaluated_[get_dimension_size] = std::move(output);
434   return Status::OK();
435 }
436 
437 Status HloEvaluator::HandleSetDimensionSize(
438     HloInstruction* set_dimension_size) {
439   const Literal& operand_literal =
440       GetEvaluatedLiteralFor(set_dimension_size->operand(0));
441   Literal result(set_dimension_size->shape());
442   memcpy(result.untyped_data(), operand_literal.untyped_data(),
443          operand_literal.size_bytes());
444   const Literal& size_literal =
445       GetEvaluatedLiteralFor(set_dimension_size->operand(1));
446   result.SetDynamicSize(set_dimension_size->dimension(),
447                         size_literal.Get<int32>({}));
448   evaluated_[set_dimension_size] = std::move(result);
449   return Status::OK();
450 }
451 
452 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
453   // Nothing to do other than sanity checks. Parameters' values are stored in
454   // arg_literals_.
455   CHECK_LT(parameter->parameter_number(), arg_literals_.size());
456 
457 #ifndef NDEBUG
458   const Literal* input_literal = arg_literals_[parameter->parameter_number()];
459   VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
460   DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
461                                                    input_literal->shape()))
462       << "parameter shape is: "
463       << ShapeUtil::HumanStringWithLayout(parameter->shape())
464       << ", but input literal shape is: "
465       << ShapeUtil::HumanStringWithLayout(input_literal->shape());
466 #endif
467 
468   return Status::OK();
469 }
470 
471 Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
472 
473 Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
474   TF_ASSIGN_OR_RETURN(
475       evaluated_[reshape],
476       GetEvaluatedLiteralFor(reshape->operand(0))
477           .Reshape(AsInt64Slice(reshape->shape().dimensions())));
478   return Status::OK();
479 }
480 
481 Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
482   evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
483                               .Transpose(transpose->dimensions());
484   return Status::OK();
485 }
486 
487 Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
488   absl::Span<HloInstruction* const> operands(concatenate->operands());
489   // The result concatenate dimension is going to be the sum of all
490   // concatenate dimensions of the operands taking part in the operation.
491   const Shape& reference_shape = operands[0]->shape();
492   CHECK(reference_shape.IsArray());
493   const int64 rank = reference_shape.rank();
494   const int64 concat_dim = concatenate->dimensions()[0];
495   CHECK_GE(concat_dim, 0);
496   CHECK_LT(concat_dim, rank);
497 
498   DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
499                                     reference_shape.dimensions().end());
500 
501   for (int64 i = 1; i < operands.size(); ++i) {
502     const Shape& operand_shape = operands[i]->shape();
503     CHECK(operand_shape.IsArray());
504     // Accumulate the concat dimension from all tensors taking part in the
505     // operation.
506     concat_dimensions[concat_dim] +=
507         ShapeUtil::GetDimension(operand_shape, concat_dim);
508   }
509 
510   auto result_literal = LiteralUtil::CreateFromDimensions(
511       reference_shape.element_type(), concat_dimensions);
512   DimensionVector source_indices(rank, 0);
513   DimensionVector dest_indices(concat_dimensions.size(), 0);
514 
515   for (auto operand : operands) {
516     const Shape& operand_shape = operand->shape();
517     TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
518         GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
519         AsInt64Slice(operand_shape.dimensions())));
520     dest_indices[concat_dim] +=
521         ShapeUtil::GetDimension(operand_shape, concat_dim);
522   }
523 
524   evaluated_[concatenate] = std::move(result_literal);
525   return Status::OK();
526 }
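// For example, concatenating operands of shapes f32[2,3] and f32[2,5] along
// dimension 1 produces concat_dimensions {2, 8}; the second operand is copied
// starting at dest_indices {0, 3}.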
527 
528 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
529   auto operand = is_finite->operand(0);
530   auto elem_ty = operand->shape().element_type();
531   switch (elem_ty) {
532     case PRED:
533     case TUPLE:
534     case OPAQUE_TYPE:
535     case TOKEN:
536     case S8:
537     case S16:
538     case S32:
539     case S64:
540     case U8:
541     case U16:
542     case U32:
543     case U64:
544     case C64:
545     case C128:
546     // Explicitly enumerate all types in this switch so that when we add a new
547     // type, we'll get a compile error here.
548     case PRIMITIVE_TYPE_INVALID:
549     case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
550     case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
551       return InvalidArgument(
552           "expected element type in shape to be floating point, but "
553           "got: %s",
554           PrimitiveType_Name(elem_ty));
555 
556     case F16: {
557       auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
558           is_finite,
559           [](Eigen::half elem_operand) {
560             return std::isfinite(static_cast<float>(elem_operand));
561           },
562           GetEvaluatedLiteralFor(operand));
563       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
564       break;
565     }
566     case BF16: {
567       auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
568           is_finite,
569           [](bfloat16 elem_operand) {
570             return std::isfinite(static_cast<float>(elem_operand));
571           },
572           GetEvaluatedLiteralFor(operand));
573       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
574       break;
575     }
576     case F32: {
577       auto result_or = ElementWiseUnaryOpImpl<bool, float>(
578           is_finite,
579           [](float elem_operand) { return std::isfinite(elem_operand); },
580           GetEvaluatedLiteralFor(operand));
581       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
582       break;
583     }
584     case F64: {
585       auto result_or = ElementWiseUnaryOpImpl<bool, double>(
586           is_finite,
587           [](double elem_operand) { return std::isfinite(elem_operand); },
588           GetEvaluatedLiteralFor(operand));
589       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
590       break;
591     }
592   }
593 
594   return Status::OK();
595 }
596 
597 Status HloEvaluator::HandleReal(HloInstruction* real) {
598   auto operand = real->operand(0);
599   switch (operand->shape().element_type()) {
600     case BF16: {
601       auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
602           real, [](bfloat16 elem_operand) { return elem_operand; },
603           GetEvaluatedLiteralFor(operand));
604       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
605       break;
606     }
607     case C64: {
608       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
609           real, [](complex64 elem_operand) { return std::real(elem_operand); },
610           GetEvaluatedLiteralFor(operand));
611       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
612       break;
613     }
614     case C128: {
615       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
616           real, [](complex128 elem_operand) { return std::real(elem_operand); },
617           GetEvaluatedLiteralFor(operand));
618       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
619       break;
620     }
621     case F16: {
622       auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
623           real, [](Eigen::half elem_operand) { return elem_operand; },
624           GetEvaluatedLiteralFor(operand));
625       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
626       break;
627     }
628     case F32: {
629       auto result_or = ElementWiseUnaryOpImpl<float, float>(
630           real, [](float elem_operand) { return elem_operand; },
631           GetEvaluatedLiteralFor(operand));
632       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
633       break;
634     }
635     case F64: {
636       auto result_or = ElementWiseUnaryOpImpl<double, double>(
637           real, [](double elem_operand) { return elem_operand; },
638           GetEvaluatedLiteralFor(operand));
639       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
640       break;
641     }
642     default:
643       LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
644                  << PrimitiveType_Name(operand->shape().element_type());
645   }
646 
647   return Status::OK();
648 }
649 
650 Status HloEvaluator::HandleImag(HloInstruction* imag) {
651   auto operand = imag->operand(0);
652   switch (operand->shape().element_type()) {
653     case C64: {
654       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
655           imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
656           GetEvaluatedLiteralFor(imag->operand(0)));
657 
658       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
659       break;
660     }
661     case C128: {
662       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
663           imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
664           GetEvaluatedLiteralFor(imag->operand(0)));
665 
666       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
667       break;
668     }
669     default:
670       LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
671                  << PrimitiveType_Name(operand->shape().element_type());
672   }
673 
674   return Status::OK();
675 }
676 
677 Status HloEvaluator::HandleComplex(HloInstruction* complex) {
678   const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
679   const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
680   TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
681 
682   Literal result(complex->shape());
683   switch (complex->shape().element_type()) {
684     case C64: {
685       TF_RETURN_IF_ERROR(
686           result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
687             return std::complex<float>(real.Get<float>(multi_index),
688                                        imag.Get<float>(multi_index));
689           }));
690       break;
691     }
692     case C128: {
693       TF_RETURN_IF_ERROR(
694           result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
695             return std::complex<double>(real.Get<double>(multi_index),
696                                         imag.Get<double>(multi_index));
697           }));
698       break;
699     }
700     default:
701       LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
702                  << PrimitiveType_Name(complex->shape().element_type());
703   }
704 
705   evaluated_[complex] = std::move(result);
706   return Status::OK();
707 }
708 
709 Status HloEvaluator::HandleCompare(HloInstruction* compare) {
710   ComparisonDirection direction = compare->comparison_direction();
711   auto lhs = compare->operand(0);
712   auto rhs = compare->operand(1);
713   DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
714          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
715 
716   TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
717 
718   const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
719   const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
720 
721   // Note here we switch on the operand's type.
722   switch (lhs->shape().element_type()) {
723     case PRED: {
724       TF_ASSIGN_OR_RETURN(
725           evaluated_[compare],
726           Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
727     } break;
728     case U8: {
729       TF_ASSIGN_OR_RETURN(evaluated_[compare],
730                           Compare<uint8>(compare->shape(), direction,
731                                          lhs_literal, rhs_literal));
732     } break;
733     case U16: {
734       TF_ASSIGN_OR_RETURN(evaluated_[compare],
735                           Compare<uint16>(compare->shape(), direction,
736                                           lhs_literal, rhs_literal));
737     } break;
738     case U32: {
739       TF_ASSIGN_OR_RETURN(evaluated_[compare],
740                           Compare<uint32>(compare->shape(), direction,
741                                           lhs_literal, rhs_literal));
742     } break;
743     case U64: {
744       TF_ASSIGN_OR_RETURN(evaluated_[compare],
745                           Compare<uint64>(compare->shape(), direction,
746                                           lhs_literal, rhs_literal));
747     } break;
748     case S8: {
749       TF_ASSIGN_OR_RETURN(
750           evaluated_[compare],
751           Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
752     } break;
753     case S16: {
754       TF_ASSIGN_OR_RETURN(evaluated_[compare],
755                           Compare<int16>(compare->shape(), direction,
756                                          lhs_literal, rhs_literal));
757     } break;
758     case S32: {
759       TF_ASSIGN_OR_RETURN(evaluated_[compare],
760                           Compare<int32>(compare->shape(), direction,
761                                          lhs_literal, rhs_literal));
762     } break;
763     case S64: {
764       TF_ASSIGN_OR_RETURN(evaluated_[compare],
765                           Compare<int64>(compare->shape(), direction,
766                                          lhs_literal, rhs_literal));
767     } break;
768     case F16: {
769       TF_ASSIGN_OR_RETURN(
770           evaluated_[compare],
771           Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
772     } break;
773     case BF16: {
774       TF_ASSIGN_OR_RETURN(evaluated_[compare],
775                           Compare<bfloat16>(compare->shape(), direction,
776                                             lhs_literal, rhs_literal));
777     } break;
778     case F32: {
779       TF_ASSIGN_OR_RETURN(evaluated_[compare],
780                           Compare<float>(compare->shape(), direction,
781                                          lhs_literal, rhs_literal));
782     } break;
783     case F64: {
784       TF_ASSIGN_OR_RETURN(evaluated_[compare],
785                           Compare<double>(compare->shape(), direction,
786                                           lhs_literal, rhs_literal));
787     } break;
788     case C64: {
789       TF_ASSIGN_OR_RETURN(evaluated_[compare],
790                           Compare<complex64>(compare->shape(), direction,
791                                              lhs_literal, rhs_literal));
792     } break;
793     case C128: {
794       TF_ASSIGN_OR_RETURN(evaluated_[compare],
795                           Compare<complex128>(compare->shape(), direction,
796                                               lhs_literal, rhs_literal));
797     } break;
798     default:
799       LOG(FATAL) << "HandleCompare: unknown primitive type: "
800                  << PrimitiveType_Name(lhs->shape().element_type());
801   }
802 
803   return Status::OK();
804 }
805 
806 Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
807   std::vector<const Literal*> operand_literals;
808   for (auto operand : tuple->operands()) {
809     operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
810   }
811 
812   evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
813   return Status::OK();
814 }
815 
816 namespace {
817 
818 // Common code used by the 1D implementations; copies data from the input into
819 // the contiguous buffer. Returns true if all copied values are zero.
820 bool GatherToBuffer(absl::Span<complex128> data, int64 length, int64 start,
821                     int64 stride, bool expand_input,
822                     absl::Span<complex128> buffer) {
823   CHECK_GE(buffer.size(), length);
824   bool input_is_zero = true;
825   const int64 ub = expand_input ? length / 2 + 1 : length;
826   CHECK_GE(data.size(), start + (ub - 1) * stride);
827   for (int64 k = 0; k < ub; k++) {
828     complex128 value = data[start + k * stride];
829     input_is_zero &= value == complex128(0.0, 0.0);
830     buffer[k] = value;
831     if (expand_input) {
832       // Use conjugates of the values at indices [1 ... (ub - 2)] when the
833       // length is even and at indices [1 ... (ub - 1)] when the length is odd
834       // to calculate missing values at indices [(length - 1) ... ub].
835       if (k > 0 && k < (length - ub + 1)) {
836         buffer[length - k] = std::conj(value);
837       }
838     }
839   }
840   return input_is_zero;
841 }
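// For example, with length == 4 and expand_input == true, only indices 0..2
// (length / 2 + 1 values) are read from 'data'; buffer[3] is then filled with
// std::conj(buffer[1]), restoring the conjugate-symmetric half of the spectrum.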
842 
843 // Returns the k-th twiddle factor for the given length, conjugated if 'inverse' is true.
844 inline complex128 Twiddle(int64 k, int64 length, bool inverse) {
845   auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * k / length));
846   return inverse ? std::conj(coeff) : coeff;
847 }
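// For example:
//   complex128 t = Twiddle(1, 4, /*inverse=*/false);  // exp(-i * pi / 2),
//                                                     // i.e. approximately (0.0, -1.0).
// With inverse == true the conjugate, approximately (0.0, 1.0), is returned.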
848 
849 // Straightforward implementation of 1D DFT transform of arbitrary length. Uses
850 // passed-in start index and stride to gather inputs from the data vector into
851 // the preallocated buffer, computes the result, and writes it back to the same
852 // locations in the data vector. Runs in O(length^2) time.
853 //
854 // Parameters contract_output and expand_input are used to avoid unnecessary
855 // calculations. When contract_output is set to true, then only (length / 2) + 1
856 // output values are computed. When expand_input is set to true, then
857 // (length / 2) + 1 values from the data set are used to re-create the full set
858 // of size 'length', on which the transform is then performed.
859 //
860 void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse,
861                 bool contract_output, bool expand_input,
862                 absl::Span<complex128> data, absl::Span<complex128> buffer) {
863   const bool input_is_zero =
864       GatherToBuffer(data, length, start, stride, expand_input, buffer);
865 
866   if (!input_is_zero) {
867     const int64 ub = contract_output ? length / 2 + 1 : length;
868     for (int64 k = 0; k < ub; k++) {
869       complex128 value = complex128(0.0, 0.0);
870       for (int n = 0; n < length; n++) {
871         value += buffer[n] * Twiddle(n * k, length, inverse);
872       }
873       data[start + k * stride] =
874           inverse ? value / complex128(length, 0.0) : value;
875     }
876   }
877 }
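// In other words, for a forward transform this computes
//   X[k] = sum_{n=0}^{length-1} x[n] * exp(-2 * pi * i * n * k / length),
// and the inverse transform additionally divides each output value by 'length'.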
878 
879 // Non-recursive implementation of the Cooley-Tukey radix-2 decimation in time.
880 // Performs 1D FFT transform for the lengths, which are powers of 2. Runs in
881 // O(length * log(length)) time. Uses the same parameters as the naive
882 // implementation above, except that the preallocated buffer must be at least
883 // twice as big as the length of the transform, because the buffer is used to
884 // hold both input and output values for each stage of the transform.
885 //
886 void Fft1D(int64 length, int64 start, int64 stride, bool inverse,
887            bool contract_output, bool expand_input, absl::Span<complex128> data,
888            absl::Span<complex128> buffer) {
889   CHECK(IsPowerOfTwo(static_cast<uint64>(length)));
890   const bool input_is_zero =
891       GatherToBuffer(data, length, start, stride, expand_input, buffer);
892 
893   if (!input_is_zero) {
894     auto generate_twiddles = [](int64 length, bool inverse) {
895       std::vector<complex128> twiddles;
896       // Need only half the twiddles.
897       for (int64 k = 0; k < length / 2; k++) {
898         twiddles.push_back(Twiddle(k, length, inverse));
899       }
900       return twiddles;
901     };
902 
903     // Indices into the parts of the buffer used for input and output values.
904     int64 in_base = length;
905     int64 out_base = 0;
906 
907     // At each stage, we "split" the input data into num_blocks, with block_size
908     // values in each block.
909     for (int64 num_blocks = 1; num_blocks < length; num_blocks *= 2) {
910       // Swap input and output parts of the buffer.
911       std::swap(in_base, out_base);
912       auto twiddles = generate_twiddles(num_blocks * 2, inverse);
913       const int64 block_size = length / num_blocks;
914       const int64 next_iteration_block_size = block_size / 2;
915       for (int64 block = 0; block < num_blocks; block++) {
916         const int64 in_offset = in_base + block * block_size;
917         const int64 out_offset = out_base + block * next_iteration_block_size;
918         // For each (even, odd) pair of values in the block, calculate two
919         // output values as even + twiddle * odd and even - twiddle * odd.
920         for (int64 pair = 0; pair < block_size / 2; pair++) {
921           const complex128 even = buffer[in_offset + pair];
922           const complex128 odd = buffer[in_offset + block_size / 2 + pair];
923           const complex128 twiddled_odd = twiddles[block] * odd;
924           buffer[out_offset + pair] = even + twiddled_odd;
925           buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
926         }
927       }
928     }
929     // Copy computed result back to data.
930     const int64 ub = contract_output ? length / 2 + 1 : length;
931     for (int64 k = 0; k < ub; k++) {
932       complex128 value = buffer[out_base + k];
933       data[start + k * stride] =
934           inverse ? value / complex128(length, 0.0) : value;
935     }
936   }
937 }
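// For example, a transform of length 8 runs three stages (num_blocks = 1, 2, 4)
// and therefore needs a buffer of at least 16 complex values: one half holds the
// inputs of the current stage and the other half receives its outputs.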
938 
939 // Determines which implementation of the 1D transform to use and calls it.
940 void Dft1D(int64 length, int64 start, int64 stride, bool inverse,
941            bool contract_output, bool expand_input, absl::Span<complex128> data,
942            absl::Span<complex128> buffer) {
943   if (IsPowerOfTwo(static_cast<uint64>(length))) {
944     Fft1D(length, start, stride, inverse, contract_output, expand_input, data,
945           buffer);
946   } else {
947     NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
948                data, buffer);
949   }
950 }
951 
952 // Helper to reverse the order of dimension lengths in the passed-in literal.
953 std::vector<int64> GetDimensionLengths(const Literal& literal) {
954   auto dimensions = literal.shape().dimensions();
955   return std::vector<int64>(dimensions.rbegin(), dimensions.rend());
956 }
957 
958 // Helper to compute strides for creating linear indices into multidimensional
959 // data from the dimension lengths and the layout. Returns a new vector of size
960 // lengths.size() + 1. The last element of the returned vector at index
961 // [lengths.size()] contains the product of all dimension lengths.
962 std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
963                                   const Layout& layout) {
964   const int64 num_dimensions = lengths.size();
965 
966   // Make sure that the layout length matches the number of dimensions.
967   CHECK_EQ(num_dimensions, layout.minor_to_major_size());
968 
969   // Calculate strides using layout-specified ordering of the dimensions and
970   // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
971   std::vector<int64> strides(num_dimensions + 1);
972   int64 stride = 1;
973   for (int64 i = 0; i < num_dimensions; i++) {
974     // Reverse the ordering of the dimensions in the layout.
975     const int64 index = (num_dimensions - 1) - layout.minor_to_major(i);
976     strides[index] = stride;
977     stride *= lengths[index];
978   }
979   strides[num_dimensions] = stride;
980 
981   return strides;
982 }
983 
984 // Compute strides as above using the default layout.
985 std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths) {
986   return ComputeStrides(lengths,
987                         LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
988 }
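// For example, with the default (descending) layout and lengths {2, 3, 4} the
// result is {1, 2, 6, 24}: axis 0 is minor-most in this reversed-lengths
// convention, and the last entry is the total number of elements.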
989 
990 // Compute strides as above using the layout from the literal, if available.
991 std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
992                                   const Literal& literal) {
993   return literal.shape().has_layout()
994              ? ComputeStrides(lengths, literal.shape().layout())
995              : ComputeStrides(lengths);
996 }
997 
998 // Make 1D sweeps along each transform axis.
999 void Sweep(int64 fft_rank, FftType fft_type,
1000            const absl::Span<const int64> fft_lengths,
1001            const absl::Span<const int64> fft_strides,
1002            absl::Span<complex128> data, absl::Span<complex128> buffer) {
1003   const bool inverse = fft_type == FftType::IFFT || fft_type == FftType::IRFFT;
1004   const bool input_is_truncated = fft_type == FftType::IRFFT;
1005   const bool output_is_truncated = fft_type == FftType::RFFT;
1006 
1007   // Recursively visit each column of the data along the sweep_axis. Calculate
1008   // linearized index of that column's first element and the stride, then invoke
1009   // 1D transform.
1010   // For RFFT, avoid calculating unused output values: first, compute only
1011   // (length_x / 2) + 1 values along the X axis, then limit the X coordinate to
1012   // [0 ... (length / 2)] during the sweeps along other axes. Similarly, for
1013   // IRFFT sweep along higher dimensions first, while keeping the X coordinate
1014   // in the [0 ... (length / 2)] range, then re-create negative frequencies
1015   // omitted in the input and perform the full-length transform along the X axis
1016   // in the last sweep.
1017   std::function<void(int64, int64, int64)> sweep = [&](int64 sweep_axis,
1018                                                        int64 axis,
1019                                                        int64 start) {
1020     if (axis < 0) {
1021       // Base case: invoke 1D transform.
1022       const int64 length = fft_lengths[sweep_axis];
1023       const int64 stride = fft_strides[sweep_axis];
1024       const bool expand_input = input_is_truncated && sweep_axis == 0;
1025       const bool contract_output = output_is_truncated && sweep_axis == 0;
1026       Dft1D(length, start, stride, inverse, contract_output, expand_input,
1027             data, buffer);
1028     } else if (axis == sweep_axis) {
1029       // Visit only the elements with coordinate 0 along the sweep axis.
1030       sweep(sweep_axis, axis - 1, start);
1031     } else {
1032       const int64 length = fft_lengths[axis];
1033       const bool is_truncated = input_is_truncated || output_is_truncated;
1034       const int64 ub = is_truncated && axis == 0 ? (length / 2) + 1 : length;
1035       for (int64 i = 0; i < ub; i++) {
1036         sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
1037       }
1038     }
1039   };
1040   if (input_is_truncated) {
1041     // Sweep along the X axis last for IRFFT.
1042     for (int64 sweep_axis = fft_rank - 1; sweep_axis >= 0; sweep_axis--) {
1043       sweep(sweep_axis, fft_rank - 1, 0);
1044     }
1045   } else {
1046     // Sweep along the X axis first for RFFT. The order does not matter for FFT
1047     // and IFFT types; handle them here as well.
1048     for (int64 sweep_axis = 0; sweep_axis < fft_rank; sweep_axis++) {
1049       sweep(sweep_axis, fft_rank - 1, 0);
1050     }
1051   }
1052 }
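// For example, for a rank-2 RFFT the X-axis sweep runs first and produces only
// (length_x / 2) + 1 columns; the subsequent sweep along the Y axis is then
// restricted to those columns, so discarded frequencies are never computed.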
1053 
1054 // These templates convert the data from the input data type to the type used in
1055 // calculations and then to the output data type. They are intended to be used
1056 // only within the DFT implementation. One special case is IRFFT, where the
1057 // specialization drops imaginary parts of complex values (which is expected to
1058 // be 0) and returns real numbers.
1059 template <typename ToType, typename FromType>
1060 ToType GetAs(FromType value) {
1061   return static_cast<ToType>(value);
1062 }
1063 
1064 template <>
1065 float GetAs<float, complex128>(complex128 value) {
1066   return static_cast<float>(value.real());
1067 }
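// For example, GetAs<complex128, float>(2.5f) yields complex128(2.5, 0.0),
// while GetAs<float, complex128>(complex128(2.5, 1e-16)) drops the (near-zero)
// imaginary part and yields 2.5f.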
1068 
1069 // This template generates two linearized indices, which can be used to access
1070 // multidimensional arrays. It uses a recursive function, which passes the
1071 // indices to the user-supplied callback function. The destination index is
1072 // always within dst_lengths[] bounds. The boolean parameter within_src_bounds
1073 // indicates whether the source index is within src_lengths[] bounds.
1074 //
1075 // The value returned from the callback function controls the recursion depth.
1076 // Returning true indicates that the base case has been hit and the recursion
1077 // stops. Otherwise, the recursion proceeds along the next less-major axis.
1078 //
1079 // For example, the base case when the axis value becomes negative invokes the
1080 // callback function for each possible index within dst_lengths[] bounds. The
1081 // base case when the axis value is equal to zero limits the indices to point
1082 // only to first elements along the minor-most dimension, allowing the callback
1083 // function to handle all values along the X axis.
1084 //
1085 template <typename BaseFn>
1086 void GenerateIndices(const absl::Span<const int64> dst_lengths,
1087                      const absl::Span<const int64> dst_strides,
1088                      const absl::Span<const int64> src_lengths,
1089                      const absl::Span<const int64> src_strides, int64 fft_rank,
1090                      int64 dst_start, int64 src_start, BaseFn&& base) {
1091   CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
1092   CHECK_GE(dst_lengths.size(), fft_rank);
1093   CHECK_EQ(src_lengths.size() + 1, src_strides.size());
1094   CHECK_GE(src_lengths.size(), fft_rank);
1095 
1096   std::function<void(int64, int64, int64, bool)> generate =
1097       [&](int64 axis, int64 dst_index, int64 src_index,
1098           bool within_src_bounds) {
1099         if (!base(axis, dst_index, src_index, within_src_bounds)) {
1100           for (int64 i = 0; i < dst_lengths[axis]; i++) {
1101             // Because the loop goes over dst_lengths[], the source index may be
1102             // out of src_lengths[] bounds. In this case, within_src_bounds is
1103             // false.
1104             within_src_bounds &= i < src_lengths[axis];
1105             generate(axis - 1, dst_index, src_index, within_src_bounds);
1106             dst_index += dst_strides[axis];
1107             src_index += src_strides[axis];
1108           }
1109         }
1110       };
1111   generate(fft_rank - 1, dst_start, src_start, true);
1112 }
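// For example, CopyDataFromInput below passes fft_lengths as dst_lengths and
// input_lengths as src_lengths, so the generated destination indices cover the
// whole transform, while within_src_bounds flags positions that need zero
// padding because the input is smaller along some axis.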
1113 
1114 // Copies the input data from a literal to a pre-allocated vector. The sizes of
1115 // the input and the transform do not need to match. For each axis of the
1116 // transform, any extra input values beyond the transform length are ignored.
1117 // Conversely, if the input does not contain enough elements along any axis, the
1118 // data is padded with zeroes.
1119 //
1120 // For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
1121 // where length_x is the size of the full transform along the X axis.
1122 //
1123 // The input literal may have a rank higher than the rank of the transform.
1124 // Passed-in input_index value points to the first element of the input literal
1125 // to be copied.
1126 //
1127 // Returns true if all values in the work data set are zeroes.
1128 //
1129 template <typename InputType>
1130 bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
1131                        int64 fft_rank, FftType fft_type, int64 fft_size,
1132                        const absl::Span<const int64> fft_lengths,
1133                        const absl::Span<const int64> fft_strides,
1134                        const absl::Span<const int64> input_lengths,
1135                        const absl::Span<const int64> input_strides,
1136                        absl::Span<complex128> data) {
1137   CHECK_GE(data.size(), fft_size);
1138 
1139   const bool input_is_truncated = fft_type == FftType::IRFFT;
1140 
1141   // Recursively visit each transform dimension to copy input values to the
1142   // working data set. The base case handles inputs along the X axis.
1143   bool input_is_zero = true;
1144   const InputType* input_data = input_literal.data<InputType>().data();
1145   auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
1146                        bool within_src_bounds) {
1147     if (axis == 0) {
1148       // For IRFFT, the negative frequencies are only needed for the sweep along
1149       // the X axis, which is performed last. Leave this part of the working set
1150       // uninitialized until then.
1151       const int64 length = fft_lengths[axis];
1152       const int64 ub = input_is_truncated ? (length / 2) + 1 : length;
1153       for (int64 i = 0; i < ub; i++) {
1154         complex128 value = InputType(0);
1155         // Read input value only if the index is within bounds.
1156         if (within_src_bounds && i < input_lengths[axis]) {
1157           value = GetAs<complex128, InputType>(
1158               input_data[src_index + i * input_strides[axis]]);
1159           input_is_zero &= value == complex128(0.0, 0.0);
1160         }
1161         data[dst_index + i * fft_strides[axis]] = value;
1162       }
1163       return true;
1164     }
1165     return false;
1166   };
1167   GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
1168                   fft_rank, 0, input_start, base_case);
1169   return input_is_zero;
1170 }
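// For example, an RFFT over fft_lengths {8} with input_lengths {5} reads input
// elements 0..4 and zero-fills positions 5..7 of the working data set; for an
// IRFFT of the same length only positions 0..4 (length / 2 + 1) are written
// here, and the negative frequencies are reconstructed later by GatherToBuffer.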
1171 
1172 // Copies the result of the transform to the literal output. The sizes of the
1173 // transform and output must match.
1174 //
1175 // For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
1176 // the size of the full transform along the X axis (the most minor dimension).
1177 //
1178 // The output literal may have a rank higher than the rank of the transform.
1179 // Passed-in output_index value points to the first element of the output
1180 // literal to be filled in.
1181 //
1182 template <typename OutputType>
1183 void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
1184                       int64 fft_rank, FftType fft_type,
1185                       const absl::Span<const int64> fft_lengths,
1186                       const absl::Span<const int64> fft_strides,
1187                       const absl::Span<const int64> output_lengths,
1188                       const absl::Span<const int64> output_strides,
1189                       Literal* output_literal) {
1190   const bool output_is_truncated = fft_type == FftType::RFFT;
1191 
1192   // Base case for recursive copy of the results to the output. The code avoids
1193   // making a recursive call for each output element by handling axis 0 in the
1194   // loop (as opposed to making "axis < 0" to be the base case).
1195   OutputType* output_data = output_literal->data<OutputType>().data();
1196   auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
1197                        bool within_src_bounds) {
1198     if (axis == 0) {
1199       // Drop negative frequencies for RFFT.
1200       const int64 length = fft_lengths[axis];
1201       const int64 ub = output_is_truncated ? (length / 2) + 1 : length;
1202       for (int64 i = 0; i < output_lengths[axis]; i++) {
1203         OutputType value = OutputType(0);
1204         // Read data only if the index is within bounds.
1205         if (within_src_bounds && i < ub) {
1206           value = GetAs<OutputType, complex128>(
1207               data[src_index + i * fft_strides[axis]]);
1208         }
1209         output_data[dst_index + i * output_strides[axis]] = value;
1210       }
1211       return true;
1212     }
1213     return false;
1214   };
1215   GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
1216                   fft_rank, output_start, 0, base_case);
1217 }
1218 
1219 // Determine the type to use with the CopyDataFromInput<> template above.
1220 bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
1221                        int64 fft_rank, FftType fft_type, int64 fft_size,
1222                        const absl::Span<const int64> fft_lengths,
1223                        const absl::Span<const int64> fft_strides,
1224                        const absl::Span<const int64> input_lengths,
1225                        const absl::Span<const int64> input_strides,
1226                        absl::Span<complex128> data) {
1227   const bool input_is_float = fft_type == FftType::RFFT;
1228   if (input_is_float) {
1229     return CopyDataFromInput<float>(
1230         input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
1231         fft_strides, input_lengths, input_strides, data);
1232   } else {
1233     return CopyDataFromInput<complex64>(
1234         input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
1235         fft_strides, input_lengths, input_strides, data);
1236   }
1237 }
1238 
1239 // Determine the type to use with the CopyDataToOutput<> template above.
1240 void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
1241                       int64 fft_rank, FftType fft_type,
1242                       const absl::Span<const int64> fft_lengths,
1243                       const absl::Span<const int64> fft_strides,
1244                       const absl::Span<const int64> output_lengths,
1245                       const absl::Span<const int64> output_strides,
1246                       Literal* output_literal) {
1247   const bool output_is_float = fft_type == FftType::IRFFT;
1248   if (output_is_float) {
1249     CopyDataToOutput<float>(data, output_start, fft_rank, fft_type, fft_lengths,
1250                             fft_strides, output_lengths, output_strides,
1251                             output_literal);
1252   } else {
1253     CopyDataToOutput<complex64>(data, output_start, fft_rank, fft_type,
1254                                 fft_lengths, fft_strides, output_lengths,
1255                                 output_strides, output_literal);
1256   }
1257 }
1258 
1259 Status CheckParameters(const Shape& input_shape, const Shape& output_shape,
1260                        int64 fft_rank, FftType fft_type,
1261                        const absl::Span<const int64> fft_lengths) {
1262   // Check FFT parameters.
1263   if (fft_rank <= 0) {
1264     return InvalidArgument("Zero or negative FFT rank.");
1265   }
1266   if (*absl::c_min_element(fft_lengths) < 0) {
1267     return InvalidArgument("Negative FFT length.");
1268   }
1269 
1270   // Check input-related values.
1271   TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
1272   if (!input_shape.IsArray()) {
1273     return Unimplemented("Only array input shapes are supported.");
1274   }
1275   auto input_elt_type = input_shape.element_type();
1276   if (fft_type == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
1277     return InvalidArgument("Invalid input type: %d, must be %d (float).",
1278                            input_elt_type, PrimitiveType::F32);
1279   }
1280   if (fft_type != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
1281     return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
1282                            input_elt_type, PrimitiveType::C64);
1283   }
1284   const int64 input_rank = input_shape.rank();
1285   if (input_rank < fft_rank) {
1286     return InvalidArgument("Input shape rank is smaller than FFT rank.");
1287   }
1288 
1289   // Check output-related values.
1290   TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
1291   if (!output_shape.IsArray()) {
1292     return Unimplemented("Only array output shapes are supported.");
1293   }
1294   auto output_elt_type = output_shape.element_type();
1295   if (fft_type == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
1296     return InvalidArgument("Invalid output type: %d, must be %d (float).",
1297                            output_elt_type, PrimitiveType::F32);
1298   }
1299   if (fft_type != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
1300     return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
1301                            output_elt_type, PrimitiveType::C64);
1302   }
1303   const int64 output_rank = output_shape.rank();
1304   if (output_rank < fft_rank) {
1305     return InvalidArgument("Output shape rank is smaller than FFT rank.");
1306   }
1307 
1308   // Consistency of input and output parameters.
1309   if (input_rank != output_rank) {
1310     return InvalidArgument(
1311         "Ranks of input shape and output shape do not match.");
1312   }
1313   for (int64 dim = 0; dim < input_rank - fft_rank; dim++) {
1314     if (ShapeUtil::GetDimension(input_shape, dim) !=
1315         ShapeUtil::GetDimension(output_shape, dim)) {
1316       return InvalidArgument(
1317           "Higher dimension lengths of input shape and output shape do not "
1318           "match.");
1319     }
1320   }
1321 
1322   return Status::OK();
1323 }
1324 
1325 }  // namespace
1326 
1327 // Flexible implementation of the discrete Fourier transform. All transform
1328 // types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the arbitrary
1329 // rank and length of each dimension of the transform, and arbitrary layouts of
1330 // the input and output literals.
1331 //
1332 // The input literal in operand 0 provides input data, which must be complex64
1333 // for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed
1334 // over the innermost dimensions of the input; thus, the rank of the input data
1335 // must be the same as fft_rank or larger. The input is expected to provide Ni
1336 // values along each transform axis with one exception: for IRFFT, only
1337 // (N0 / 2) + 1 values are needed along the X axis (the innermost index). To
1338 // increase flexibility, this implementation can handle mismatches between the
1339 // input size and transform lengths by either dropping extra input values or
1340 // using zeroes in place of missing input values as necessary. If the input data
1341 // has rank higher than the transform, the transform is applied for each valid
1342 // combination of the higher-ranking indices.
1343 //
1344 // The output contains complex64 values for FFT, IFFT, RFFT, and float values
1345 // for IRFFT. The rank of the output as well as the sizes of the dimensions
1346 // above the rank of the transform must match those of the input. Sizes of the
1347 // output's "fft_rank" innermost dimensions are expected to match the length of
1348 // the transform along respective axes with one exception: for RFFT, the output
1349 // is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the
1350 // length(s) mismatch, the FFT output is trimmed to fit into the provided output
1351 // shape, or the output is padded with zero values appropriately.
1352 //
1353 // For example, a 2D FFT transform of size 16x16 applied to a
1354 // complex64[2][15][17] input array performs two transforms over the [][15][17]
1355 // data in the subarrays [0][][] and [1][][], dropping the values along axis X
1356 // and padding axis Y with zeroes to create 16x16 working sets, and generates
1357 // complex64[2][16][16] output. A 3D IRFFT transform of size 64x16x16 applied to
1358 // a complex64[64][16][9] input array uses all input values and produces
1359 // float[64][16][16] output.
1360 //
1361 // The implementation of the 1D transform for lengths that are powers of 2 is
1362 // the Cooley-Tukey radix-2 decimation-in-time. For all other 1D transform
1363 // lengths, a straightforward, but slow, loop nest is used. The transforms of
1364 // higher ranks apply sets of 1D transforms along each axis. For example, the 2D
1365 // transform is computed by applying 1D transforms to each column followed by
1366 // applying 1D transforms to each row.
1367 //
1368 // In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
1369 // time, where Ni is the length of the transform's i-th dimension. However, for
1370 // dimension lengths that are powers of 2, the run time along these dimensions
1371 // is reduced to log(Ni) in the summation, giving a runtime of
1372 // O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn))) in the best case.
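// As an illustrative count only (constant factors ignored): a 2D transform over
// a 16x16 working set takes on the order of 16*16*(log2(16)+log2(16)) = 2048
// butterfly steps on the radix-2 path, versus 16*16*(16+16) = 8192 inner-loop
// steps on the naive path.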
1373 //
1374 Status HloEvaluator::HandleFft(HloInstruction* fft) {
1375   const FftType fft_type = fft->fft_type();
1376   std::vector<int64> fft_lengths = fft->fft_length();
1377   const int64 fft_rank = fft_lengths.size();
1378   const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
1379   const Shape& input_shape = input_literal.shape();
1380   const Shape& output_shape = fft->shape();
1381   Literal output_literal = Literal::CreateFromShape(output_shape);
1382 
1383   // Make fft_lengths[0] the minor-most dimension.
1384   absl::c_reverse(fft_lengths);
1385 
1386   TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape, fft_rank,
1387                                      fft_type, fft_lengths));
1388 
1389   const auto fft_strides = ComputeStrides(fft_lengths);
1390 
1391   // Working set size.
1392   const int64 fft_size = fft_strides[fft_rank];
1393 
1394   if (fft_size > 0) {
1395     // Linearized working data set.
1396     std::vector<complex128> data(fft_size);
1397 
1398     // Temporary buffer allocated once and used in 1D sweeps. For dimension
1399     // length values that are powers of 2, the buffer should be twice as large.
1400     int64 buffer_size = 0;
1401     for (auto len : fft_lengths) {
1402       int64 size = IsPowerOfTwo(static_cast<uint64>(len)) ? len * 2 : len;
1403       buffer_size = std::max(buffer_size, size);
1404     }
1405     std::vector<complex128> buffer(buffer_size);
1406 
1407     // Sizes of each axis of input and output literals.
1408     const auto input_lengths = GetDimensionLengths(input_literal);
1409     const auto output_lengths = GetDimensionLengths(output_literal);
1410 
1411     // Strides for generating linearized indices into multidimensional arrays.
1412     const auto input_strides = ComputeStrides(input_lengths, input_literal);
1413     const auto output_strides = ComputeStrides(output_lengths, output_literal);
1414 
1415     // Visit all elements in the dimensions with ranks above the FFT rank. For
1416     // each such element invoke the transform. Use separate indices for the
1417     // input and the output to allow different layouts.
1418     auto base_case = [&](int64 axis, int64 output_index, int64 input_index,
1419                          bool within_src_bounds) {
1420       if (axis == fft_rank - 1) {
1421         // Base case: copy the data from the input literal, apply the
1422         // transform, and copy the result to the output literal.
1423         CHECK(within_src_bounds);
1424         bool input_is_zero =
1425             CopyDataFromInput(input_literal, input_index, fft_rank, fft_type,
1426                               fft_size, fft_lengths, fft_strides, input_lengths,
1427                               input_strides, absl::MakeSpan(data));
1428         if (!input_is_zero) {
1429           // Make 1D sweeps along each transform axis.
1430           Sweep(fft_rank, fft_type, fft_lengths, fft_strides,
1431                 absl::MakeSpan(data), absl::MakeSpan(buffer));
1432         }
1433         CopyDataToOutput(absl::MakeSpan(data), output_index, fft_rank, fft_type,
1434                          fft_lengths, fft_strides, output_lengths,
1435                          output_strides, &output_literal);
1436         return true;
1437       }
1438       return false;
1439     };
1440     GenerateIndices(output_lengths, output_strides, input_lengths,
1441                     input_strides, input_shape.rank(), 0, 0, base_case);
1442   }
1443 
1444   evaluated_[fft] = std::move(output_literal);
1445   return Status::OK();
1446 }
1447 
1448 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
1449 // dimensions while keeping the rest of the output dimensions clamped to 0.
1450 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
1451     const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
1452   int64 output_rank = output_shape.dimensions_size();
1453   std::vector<int64> index_base(output_rank, 0);
1454   std::vector<int64> index_count;
1455   index_count.reserve(output_rank);
1456   for (int64 i = 0; i < output_rank; i++) {
1457     bool is_output_batch_dim =
1458         !absl::c_binary_search(dim_numbers.offset_dims(), i);
1459     index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
1460   }
1461 
1462   return {std::move(index_base), std::move(index_count),
1463           std::vector<int64>(output_rank, 1)};
1464 }
1465 
1466 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
1467 // dimensions while keeping the rest of the output dimensions clamped to 0.
1468 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
1469     int64 output_rank, absl::Span<const int64> slice_sizes,
1470     const GatherDimensionNumbers& dim_numbers) {
1471   std::vector<int64> index_base(output_rank, 0);
1472   std::vector<int64> index_count(output_rank, 1);
1473   int64 slice_sizes_idx = 0;
1474   for (int64 i = 0; i < output_rank; i++) {
1475     bool is_output_window_dim =
1476         absl::c_binary_search(dim_numbers.offset_dims(), i);
1477     if (is_output_window_dim) {
1478       while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
1479                                    slice_sizes_idx)) {
1480         slice_sizes_idx++;
1481       }
1482       index_count[i] = slice_sizes[slice_sizes_idx++];
1483     }
1484   }
1485 
1486   return {std::move(index_base), std::move(index_count),
1487           std::vector<int64>(output_rank, 1)};
1488 }
1489 
1490 // This functor computes the contribution of start_indices to an input index
1491 // corresponding to an output index.  That is, given an output index I, it picks
1492 // out the batch indices in I and uses them to look up a starting index, G, from
1493 // the start indices tensor, and expands G into the input space according to
1494 // start_index_map.
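// As a hypothetical example: with start_index_map = {0, 2}, a rank-3 input, and
// an index vector G = {7, 3} fetched from start_indices, the contribution to the
// input index is {7, 0, 3}: each input dimension named in start_index_map
// receives the corresponding component of G, and every other dimension
// contributes 0.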
1495 class OutputBatchIndexToInputIndex {
1496  public:
1497   // The constructor does some setup work that is amortized across all
1498   // iterations.
1499   explicit OutputBatchIndexToInputIndex(
1500       const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
1501       const Shape& output_shape, const Literal* start_indices)
1502       : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
1503     for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
1504       output_dim_is_batch_dims_.push_back(
1505           !absl::c_binary_search(dim_numbers_.offset_dims(), i));
1506     }
1507 
1508     for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
1509       int64 index_of_input_dim_in_index_vector =
1510           std::distance(dim_numbers_.start_index_map().begin(),
1511                         absl::c_find(dim_numbers_.start_index_map(), i));
1512       if (index_of_input_dim_in_index_vector ==
1513           dim_numbers_.start_index_map_size()) {
1514         input_dim_value_to_index_vector_.push_back(-1);
1515       } else {
1516         input_dim_value_to_index_vector_.push_back(
1517             index_of_input_dim_in_index_vector);
1518       }
1519     }
1520 
1521     index_vector_index_.resize(start_indices_.shape().dimensions_size());
1522     input_index_.resize(input_shape.dimensions_size());
1523     int64 index_vector_size =
1524         start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
1525     index_vector_.resize(index_vector_size);
1526   }
1527 
1528   // Returns the contribution of start_indices to the input index corresponding
1529   // to output_index.  See gather_inner_loop_body.
1530   //
1531   // This is conceptually  a stateless transformation from output_index to the
1532   // gather input index, but:
1533   //
1534   //  - Instead of allocating memory to represent the gather input index on
1535   //    every invocation we reuse the same storage for the result
1536   //    (input_index_), mutating it in place.
1537   //  - Instead of allocating buffers for temporary values like
1538   //    index_vector_index_ and index_vector on every invocation, we reuse the
1539   //    same storage for all invocations.
1540   //
1541   // This returns a Span into memory owned by the class.
1542   StatusOr<absl::Span<const int64>> operator()(
1543       absl::Span<const int64> output_index) {
1544     PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
1545     TF_RETURN_IF_ERROR(FetchIndexVector());
1546     PropagateIndexVectorToInputIndex();
1547     return absl::Span<const int64>(input_index_);
1548   }
1549 
1550  private:
1551   // Propagates the batch dimensions from the output index into
1552   // index_vector_index_ by mutating index_vector_index_ in place.  Does not
1553   // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
1554   // we iterate over in FetchIndexVector.
1555   void PropagateOutputIndexGatherDimsToIndexVectorIndex(
1556       absl::Span<const int64> output_index) {
1557     int64 index_vector_index_i = 0;
1558     for (int64 i = 0, e = output_index.size(); i < e; i++) {
1559       if (!output_dim_is_batch_dims_[i]) {
1560         continue;
1561       }
1562 
1563       if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
1564         index_vector_index_i++;
1565       }
1566 
1567       index_vector_index_[index_vector_index_i++] = output_index[i];
1568     }
1569   }
1570 
1571   // Populates index_vector_ by iterating over start_indices_ according to
1572   // index_vector_index_.
1573   Status FetchIndexVector() {
1574     int64 index_vector_dim = dim_numbers_.index_vector_dim();
1575     for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
1576       index_vector_index_[index_vector_dim] = i;
1577       auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
1578       TF_RET_CHECK(start_index.has_value());
1579       index_vector_[i] = *start_index;
1580     }
1581     return Status::OK();
1582   }
1583 
1584   // Populates input_index_.
1585   void PropagateIndexVectorToInputIndex() {
1586     for (int64 i = 0, e = input_index_.size(); i < e; i++) {
1587       if (input_dim_value_to_index_vector_[i] != -1) {
1588         input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
1589       }
1590 
1591       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
1592       // remains 0, as set by the constructor.
1593     }
1594   }
1595 
1596   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
1597   // the input index from the index vector.  See
1598   // PropagateIndexVectorToInputIndex.
1599   std::vector<int64> input_dim_value_to_index_vector_;
1600 
1601   // output_dim_is_batch_dims_[i] is true iff output dimension i is a batch
1602   // dimension.
1603   std::vector<bool> output_dim_is_batch_dims_;
1604 
1605   // The buffer into which we construct an index into start_indices_ to fetch
1606   // the index vector.
1607   std::vector<int64> index_vector_index_;
1608 
1609   // The index vector fetched from start_indices_.
1610   std::vector<int64> index_vector_;
1611 
1612   // The result computed by this functor.  operator() returns a Span into
1613   // this vector.
1614   std::vector<int64> input_index_;
1615 
1616   const GatherDimensionNumbers& dim_numbers_;
1617   const Literal& start_indices_;
1618 };
1619 
1620 // This functor computes the contribution of the offset indices in an output
1621 // index to an input index.  That is, given an output index I, it picks out the
1622 // output offset indices in I and expands them into an index into the input
1623 // shape.
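// As a hypothetical example: with offset_dims = {1, 2}, no collapsed slice
// dimensions, and a rank-2 input, an output index {5, 3, 4} contributes {3, 4}
// to the input index; the batch component 5 is handled separately by
// OutputBatchIndexToInputIndex above.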
1623 class OutputOffsetIndexToInputIndex {
1624  public:
1625   // The constructor does some setup work that is amortized across all
1626   // iterations.
1627   explicit OutputOffsetIndexToInputIndex(
1628       const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
1629       const Shape& output_shape) {
1630     std::vector<int64> window_index_to_output_index;
1631     int64 output_index_count = 0;
1632     for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
1633       if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
1634         window_index_to_output_index.push_back(output_index_count++);
1635       } else {
1636         output_index_count++;
1637       }
1638     }
1639 
1640     int64 window_dim_count = 0;
1641     for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
1642       if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
1643         input_dim_value_to_output_index_.push_back(-1);
1644       } else {
1645         input_dim_value_to_output_index_.push_back(
1646             window_index_to_output_index[window_dim_count++]);
1647       }
1648     }
1649 
1650     input_index_.resize(input_shape.dimensions_size());
1651   }
1652 
1653   // Returns the contribution of the window indices to the input index
1654   // corresponding to output_index.  See gather_inner_loop_body.
1655   //
1656   // This is conceptually a stateless transformation from output_index to the
1657   // window input index, but instead of allocating memory to represent the
1658   // gather input index on every invocation we reuse the same storage for the
1659   // result (input_index_), mutating it in place.
1660   //
1661   // This returns a Span into memory owned by the class.
1662   StatusOr<absl::Span<const int64>> operator()(
1663       absl::Span<const int64> output_index) {
1664     PropagateOutputIndexWindowDimsToInputIndex(output_index);
1665     return absl::Span<const int64>(input_index_);
1666   }
1667 
1668   // Returns for a given 'input_dim' the corresponding output dimension index,
1669   // or -1 if 'input_dim' is an elided window dimension.
1670   int64 input_dim_value_to_output_index(int64 input_dim) {
1671     return input_dim_value_to_output_index_[input_dim];
1672   }
1673 
1674  private:
1675   // Propagates window dimensions from the output index to input_index_ by
1676   // mutating input_index_ in place.
1677   void PropagateOutputIndexWindowDimsToInputIndex(
1678       absl::Span<const int64> output_index) {
1679     for (int64 i = 0, e = input_index_.size(); i < e; i++) {
1680       if (input_dim_value_to_output_index_[i] != -1) {
1681         input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
1682       }
1683 
1684       // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
1685       // remains 0, as set by the constructor.
1686     }
1687   }
1688 
1689   // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
1690   // the input index from the output index. See
1691   // PropagateOutputIndexWindowDimsToInputIndex.
1692   std::vector<int64> input_dim_value_to_output_index_;
1693 
1694   // The result computed by this functor.  operator() returns a Span into
1695   // this vector.
1696   std::vector<int64> input_index_;
1697 };
1698 
1699 // Reshapes the gather indices input to have a trailing degenerate `1` dimension
1700 // if necessary.  Hands over the ownership of the newly created literal (if
1701 // there is one) to `reshaped_start_indices`.
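// For example, s64[10,9] start_indices with index_vector_dim = 2 are reshaped
// to s64[10,9,1]; start_indices whose rank differs from index_vector_dim are
// returned unchanged.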
1702 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
1703     int64 index_vector_dim, const Literal& start_indices,
1704     Literal* reshaped_start_indices) {
1705   if (start_indices.shape().dimensions_size() != index_vector_dim) {
1706     return std::cref(start_indices);
1707   }
1708 
1709   std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
1710                                start_indices.shape().dimensions().end());
1711   new_shape.push_back(1);
1712   TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
1713                       start_indices.Reshape(new_shape));
1714   return std::cref(*reshaped_start_indices);
1715 }
1716 
1717 Status HloEvaluator::HandleGather(HloInstruction* gather) {
1718   Literal result = Literal::CreateFromShape(gather->shape());
1719   const Shape& shape = gather->shape();
1720   const GatherDimensionNumbers& dim_numbers =
1721       gather->gather_dimension_numbers();
1722   const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
1723   Literal reshaped_start_indices;
1724   TF_ASSIGN_OR_RETURN(
1725       const Literal& start_indices,
1726       ReshapedGatherIndices(dim_numbers.index_vector_dim(),
1727                             GetEvaluatedLiteralFor(gather->operand(1)),
1728                             &reshaped_start_indices));
1729 
1730   // We iterate over the gather dimensions in the output shape in an outer loop
1731   // nest, and iterate over the window dimensions in the output shape in an
1732   // inner loop nest.
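  // For each output index, the corresponding input index is assembled as the
  // clamped start index (derived from start_indices) plus the offset index, as
  // computed by the two functors constructed below.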
1733 
1734   ShapeUtil::IndexIterationSpace start_indices_iteration_space =
1735       IterationSpaceForOutputBatchIndices(shape, dim_numbers);
1736   ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
1737       IterationSpaceForOutputOffsetIndices(
1738           shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
1739 
1740   // Scratch buffers that hold an index in the output shape and the
1741   // corresponding index in the input shape.
1742   std::vector<int64> input_index(operand.shape().dimensions_size());
1743   std::vector<int64> output_index(gather->shape().dimensions_size());
1744   std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
1745 
1746   OutputBatchIndexToInputIndex output_batch_index_to_input_index(
1747       &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1748       /*output_shape=*/shape, &start_indices);
1749   OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
1750       gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1751       /*output_shape=*/shape);
1752 
1753   const Shape& operand_shape = operand.shape();
1754   if (ShapeUtil::IsZeroElementArray(operand_shape)) {
1755     evaluated_[gather] = std::move(result);
1756     return Status::OK();
1757   }
1758 
1759   auto gather_inner_loop_body =
1760       [&](absl::Span<const int64> output_window_index,
1761           absl::Span<const int64> input_gather_index,
1762           absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1763     TF_ASSIGN_OR_RETURN(
1764         absl::Span<const int64> input_window_index,
1765         output_offset_index_to_input_index(output_window_index));
1766     for (int i = 0, e = output_index.size(); i < e; i++) {
1767       output_index[i] = output_gather_index[i] + output_window_index[i];
1768       DCHECK_LT(output_index[i], shape.dimensions(i));
1769     }
1770     for (int i = 0, e = input_gather_index.size(); i < e; i++) {
1771       int64 output_dim =
1772           output_offset_index_to_input_index.input_dim_value_to_output_index(i);
1773       // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
1774       // we set the iteration index to 0, so for the purpose of the following
1775       // calculations we can consider the output dimension size to be 1.
1776       int64 output_dim_size =
1777           output_dim == -1 ? 1 : shape.dimensions(output_dim);
1778       // Clamp the gather index so that the gather region fits in the operand.
1779       // input_index_clamped[i] = clamp(input_gather_index[i], 0,
1780       //                                       operand_shape.dimensions(i) -
1781       //                                       output_dim_size);
1782       input_index_clamped[i] =
1783           std::min(operand_shape.dimensions(i) - output_dim_size,
1784                    std::max(int64{0}, input_gather_index[i]));
1785     }
1786     for (int i = 0, e = input_index.size(); i < e; i++) {
1787       input_index[i] = input_index_clamped[i] + input_window_index[i];
1788       DCHECK_GE(input_index[i], 0);
1789       DCHECK_LT(input_index[i], operand_shape.dimensions(i));
1790     }
1791     TF_RETURN_IF_ERROR(
1792         result.CopyElementFrom(operand, input_index, output_index));
1793     return true;
1794   };
1795 
1796   auto gather_outer_loop_body =
1797       [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1798     TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
1799                         output_batch_index_to_input_index(output_gather_index));
1800     TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1801         shape, offset_indices_iteration_space,
1802         std::bind(gather_inner_loop_body, std::placeholders::_1,
1803                   input_gather_index, output_gather_index)));
1804     return true;
1805   };
1806 
1807   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1808       shape, start_indices_iteration_space, gather_outer_loop_body));
1809   evaluated_[gather] = std::move(result);
1810   return Status::OK();
1811 }
1812 
1813 Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
1814   const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
1815 
1816   TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
1817       << "broadcast dimensions is of size: " << broadcast->dimensions().size()
1818       << " and rank of operand_to_broadcast is: " << operand.shape().rank();
1819   // Checks that operand's dimensions are the same as the broadcast's
1820   // dimensions along the dimensions to be broadcasted.
1821   for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
1822     auto operand_dim_size = operand.shape().dimensions(i);
1823     auto broadcast_dim_size =
1824         broadcast->shape().dimensions(broadcast->dimensions(i));
1825     TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
1826         "Operand dimension %d is broadcast to output dimension %d, but the "
1827         "sizes of these two dims do not match (%d vs %d): %s",
1828         i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
1829         broadcast->ToString());
1830   }
1831 
1832   TF_ASSIGN_OR_RETURN(
1833       evaluated_[broadcast],
1834       operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
1835 
1836   return Status::OK();
1837 }
1838 
1839 Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
1840   evaluated_[after_all] = LiteralUtil::CreateToken();
1841   return Status::OK();
1842 }
1843 
1844 Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
1845   // AddDependency just forwards its zeroth operand.
1846   evaluated_[add_dependency] =
1847       GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
1848   return Status::OK();
1849 }
1850 
1851 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
1852   const auto result_shape = get_tuple_element->shape();
1853   const int64 index = get_tuple_element->tuple_index();
1854 
1855   auto operand = get_tuple_element->operand(0);
1856   TF_ASSIGN_OR_RETURN(
1857       auto inferred_return_shape,
1858       ShapeInference::InferGetTupleElementShape(operand->shape(), index));
1859   TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
1860       << "return shape set to: " << ShapeUtil::HumanString(result_shape)
1861       << " but is inferred to be: "
1862       << ShapeUtil::HumanString(inferred_return_shape);
1863 
1864   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1865 
1866   evaluated_[get_tuple_element] =
1867       Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
1868   return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
1869                                                 /*dest_shape_index=*/{},
1870                                                 /*src_shape_index=*/{index});
1871 }
1872 
1873 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
1874   TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
1875   evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
1876   return Status::OK();
1877 }
1878 
1879 Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
1880   if (copy_start->user_count() != 1 ||
1881       copy_start->users().at(0)->opcode() != HloOpcode::kCopyDone) {
1882     return tensorflow::errors::FailedPrecondition(
1883         "Cannot evaluate a kCopyStart that doesn't have a single kCopyDone "
1884         "user.");
1885   }
1886 
1887   // The context in index {2} is undefined, but since we can't represent
1888   // undefined values using a Literal, we just use 0. This should be safe though
1889   // since we ensure that the only user of a kCopyStart is a kCopyDone which
1890   // consumes the context. Also note that MakeTuple copies its arguments, so
1891   // this is memory-safe.
1892   const Literal context_literal = LiteralUtil::CreateR0<uint32>(0);
1893   evaluated_[copy_start] = LiteralUtil::MakeTuple(
1894       {&GetEvaluatedLiteralFor(copy_start->operand(0)),
1895        &GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
1896   return Status::OK();
1897 }
1898 
1899 Status HloEvaluator::HandleCopyDone(HloInstruction* copy_done) {
1900   const HloInstruction* operand = copy_done->operand(0);
1901   if (operand->opcode() != HloOpcode::kCopyStart) {
1902     return tensorflow::errors::FailedPrecondition(
1903         "Cannot evaluate a kCopyDone that doesn't have a kCopyStart as "
1904         "operand.");
1905   }
1906 
1907   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1908 
1909   evaluated_[copy_done] =
1910       Literal(ShapeUtil::GetTupleElementShape(operand->shape(), /*index=*/0));
1911   TF_RETURN_IF_ERROR(evaluated_[copy_done].CopyFrom(operand_tuple_literal,
1912                                                     /*dest_shape_index=*/{},
1913                                                     /*src_shape_index=*/{0}));
1914   return Status::OK();
1915 }
1916 
1917 Status HloEvaluator::HandleCall(HloInstruction* call) {
1918   auto* computation = call->to_apply();
1919   auto operands = call->operands();
1920 
1921   std::vector<const Literal*> arg_literals;
1922   arg_literals.reserve(operands.size());
1923   for (auto operand : operands) {
1924     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1925     arg_literals.push_back(&arg_literal);
1926   }
1927 
1928   HloEvaluator embedded_evaluator;
1929   embedded_evaluator.set_dynamic_dimension_inference(
1930       dynamic_dimension_inference_);
1931   TF_ASSIGN_OR_RETURN(Literal result,
1932                       embedded_evaluator.Evaluate(*computation, arg_literals));
1933 
1934   evaluated_[call] = std::move(result);
1935   return Status::OK();
1936 }
1937 
1938 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
1939   HloModuleConfig config;
1940   // Attach cloned computation to an empty HLO module so the existing ones are
1941   // not modified.
1942   HloModule empty_hlo_module("EmptyModuleForFusion", config);
1943   HloCloneContext context(&empty_hlo_module);
1944   auto cloned_fused_computation =
1945       fusion->fused_instructions_computation()->Clone(
1946           /*suffix=*/"clone_with_layout", &context);
1947   for (auto* instruction : cloned_fused_computation->instructions()) {
1948     if (!LayoutUtil::HasLayout(instruction->shape())) {
1949       LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
1950     }
1951   }
1952   auto readded_computation =
1953       empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
1954 
1955   auto operands = fusion->operands();
1956   std::vector<const Literal*> arg_literals;
1957   arg_literals.reserve(operands.size());
1958   for (auto operand : operands) {
1959     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1960     arg_literals.push_back(&arg_literal);
1961   }
1962 
1963   HloEvaluator embedded_evaluator;
1964   embedded_evaluator.set_dynamic_dimension_inference(
1965       dynamic_dimension_inference_);
1966   TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
1967                                           *readded_computation, arg_literals));
1968 
1969   evaluated_[fusion] = std::move(result);
1970   return Status::OK();
1971 }
1972 
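// For a PRED branch selector, branch 0 corresponds to true and branch 1 to
// false; for an integer selector, an out-of-range branch index falls through to
// the last branch, which acts as the default.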
1973 Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
1974   const auto& branch_index_literal =
1975       GetEvaluatedLiteralFor(conditional->operand(0));
1976   int branch_index;
1977   if (conditional->operand(0)->shape().element_type() == PRED) {
1978     branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
1979   } else {
1980     branch_index = branch_index_literal.Get<int32>({});
1981     if (branch_index < 0 || branch_index >= conditional->branch_count()) {
1982       branch_index = conditional->branch_count() - 1;
1983     }
1984   }
1985   const auto& branch_computation_arg =
1986       GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
1987 
1988   HloEvaluator embedded_evaluator;
1989   embedded_evaluator.set_dynamic_dimension_inference(
1990       dynamic_dimension_inference_);
1991   TF_ASSIGN_OR_RETURN(Literal result,
1992                       embedded_evaluator.Evaluate(
1993                           *conditional->branch_computation(branch_index),
1994                           {&branch_computation_arg}));
1995 
1996   evaluated_[conditional] = std::move(result);
1997   return Status::OK();
1998 }
1999 
2000 Status HloEvaluator::HandleSelect(HloInstruction* select) {
2001   const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
2002   const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
2003   const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
2004 
2005   // If the predicate is of scalar type, no element-wise selection is needed.
2006   if (ShapeUtil::IsScalar(pred.shape())) {
2007     if (pred.Get<bool>({})) {
2008       evaluated_[select] = on_true.Clone();
2009     } else {
2010       evaluated_[select] = on_false.Clone();
2011     }
2012     return Status::OK();
2013   }
2014 
2015   return DefaultAction(select);
2016 }
2017 
2018 Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
2019   const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
2020   const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
2021   const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
2022 
2023   if (pred.Get<bool>({})) {
2024     evaluated_[tuple_select] = on_true.Clone();
2025   } else {
2026     evaluated_[tuple_select] = on_false.Clone();
2027   }
2028   return Status::OK();
2029 }
2030 
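// Evaluates a while loop by alternating between the condition and body
// computations on the loop-carried value until the condition evaluates to false
// (or the iteration limit, if any, is exceeded).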
2031 Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
2032   HloComputation* cond_comp = while_hlo->while_condition();
2033   HloComputation* body_comp = while_hlo->while_body();
2034   // Initialize the loop-carried value with the input to the While instruction.
2035   auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
2036   bool keep_going = true;
2037   int64 iteration_count = 0;
2038   HloEvaluator cond_evaluator(max_loop_iterations_);
2039   cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
2040   HloEvaluator loop_body_evaluator(max_loop_iterations_);
2041   loop_body_evaluator.set_dynamic_dimension_inference(
2042       dynamic_dimension_inference_);
2043   while (keep_going) {
2044     if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
2045       return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
2046                              while_hlo->name(), max_loop_iterations_);
2047     }
2048     TF_ASSIGN_OR_RETURN(auto cond_val,
2049                         cond_evaluator.Evaluate(*cond_comp, {&lcv}));
2050     keep_going = cond_val.GetFirstElement<bool>();
2051     if (keep_going) {
2052       TF_ASSIGN_OR_RETURN(auto body_val,
2053                           loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
2054       VLOG(3) << "Loop iteration result: " << body_val.ToString();
2055       lcv = std::move(body_val);
2056       cond_evaluator.ResetVisitStates();
2057       loop_body_evaluator.ResetVisitStates();
2058     }
2059   }
2060   evaluated_[while_hlo] = std::move(lcv);
2061   return Status::OK();
2062 }
2063 
2064 namespace {
2065 template <typename NativeT>
2066 Literal ExtractLiteralFromIndexPositions(const Literal& from,
2067                                          absl::Span<int64 const> indices,
2068                                          bool extract_as_scalar) {
2069   if (extract_as_scalar) {
2070     return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
2071   }
2072   // We use an InlinedVector here because we need to convert it to an
2073   // absl::Span later, and this would not work with std::vector<bool>.
2074   absl::InlinedVector<NativeT, 10> values;
2075   for (int64 index : indices) {
2076     values.push_back(from.Get<NativeT>({index}));
2077   }
2078   return LiteralUtil::CreateR1<NativeT>(values);
2079 }
2080 
2081 StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
2082                                             absl::Span<int64 const> indices,
2083                                             bool extract_as_scalar = false) {
2084   if (extract_as_scalar) {
2085     CHECK_EQ(indices.size(), 1);
2086   }
2087   PrimitiveType type = from.shape().element_type();
2088   switch (type) {
2089     case PRED: {
2090       return ExtractLiteralFromIndexPositions<bool>(from, indices,
2091                                                     extract_as_scalar);
2092     }
2093     case U8: {
2094       return ExtractLiteralFromIndexPositions<uint8>(from, indices,
2095                                                      extract_as_scalar);
2096     }
2097     case S8: {
2098       return ExtractLiteralFromIndexPositions<int8>(from, indices,
2099                                                     extract_as_scalar);
2100     }
2101     case BF16: {
2102       return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
2103                                                         extract_as_scalar);
2104     }
2105     case F16: {
2106       return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
2107                                                            extract_as_scalar);
2108     }
2109     case U16: {
2110       return ExtractLiteralFromIndexPositions<uint16>(from, indices,
2111                                                       extract_as_scalar);
2112     }
2113     case S16: {
2114       return ExtractLiteralFromIndexPositions<int16>(from, indices,
2115                                                      extract_as_scalar);
2116     }
2117     case F32: {
2118       return ExtractLiteralFromIndexPositions<float>(from, indices,
2119                                                      extract_as_scalar);
2120     }
2121     case U32: {
2122       return ExtractLiteralFromIndexPositions<uint32>(from, indices,
2123                                                       extract_as_scalar);
2124     }
2125     case S32: {
2126       return ExtractLiteralFromIndexPositions<int32>(from, indices,
2127                                                      extract_as_scalar);
2128     }
2129     case F64: {
2130       return ExtractLiteralFromIndexPositions<double>(from, indices,
2131                                                       extract_as_scalar);
2132     }
2133     case C64: {
2134       return ExtractLiteralFromIndexPositions<std::complex<float>>(
2135           from, indices, extract_as_scalar);
2136     }
2137     case U64: {
2138       return ExtractLiteralFromIndexPositions<uint64>(from, indices,
2139                                                       extract_as_scalar);
2140     }
2141     case S64: {
2142       return ExtractLiteralFromIndexPositions<int64>(from, indices,
2143                                                      extract_as_scalar);
2144     }
2145     case C128: {
2146       return ExtractLiteralFromIndexPositions<std::complex<double>>(
2147           from, indices, extract_as_scalar);
2148     }
2149     default:
2150       return InvalidArgument("Unsupported type for Sort: %s",
2151                              PrimitiveType_Name(type));
2152   }
2153 }
2154 }  // namespace
2155 
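// Sorts each 1-D slice of the operands along the sort dimension: for every such
// slice, an index permutation is built with std::sort / std::stable_sort using
// the comparator computation attached to the instruction, and the permutation is
// then applied to the corresponding slice of every operand.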
2156 Status HloEvaluator::HandleSort(HloInstruction* sort) {
2157   TF_RET_CHECK(sort->operand_count() >= 1)
2158       << "Expected at least 1 operand for sort";
2159   for (int64 i = 1; i < sort->operand_count(); ++i) {
2160     TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
2161                                            sort->operand(i)->shape()))
2162         << "All Sort operands must have the same dimensions";
2163   }
2164 
2165   if (VLOG_IS_ON(3)) {
2166     for (int64 i = 0; i < sort->operand_count(); ++i) {
2167       VLOG(3) << "HandleSort operand " << i << " literal: "
2168               << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
2169     }
2170   }
2171   Shape key_shape = sort->operand(0)->shape();
2172   auto rank = key_shape.rank();
2173   std::vector<Literal> result_literals;
2174   result_literals.reserve(sort->operand_count());
2175   for (int64 i = 0; i < sort->operand_count(); ++i) {
2176     result_literals.emplace_back(sort->operand(i)->shape());
2177   }
2178   std::vector<int64> zero_base(rank, 0);
2179   std::vector<int64> increment(rank, 1);
2180   int64 sort_dim = sort->dimensions(0);
2181   int64 sort_dim_elements = key_shape.dimensions(sort_dim);
2182   increment[sort_dim] = sort_dim_elements;
2183   HloEvaluator embedded_evaluator(max_loop_iterations_);
2184   // Iterate through each dimension except 'sort_dim'.
2185   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2186       key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
2187       [&](absl::Span<const int64> indices) -> StatusOr<bool> {
2188         // Extract a slice from each operand literal that corresponds to
2189         // exactly the row in dimension 'sort_dim'.
2190         std::vector<int64> limit_indices(indices.begin(), indices.end());
2191         absl::c_for_each(limit_indices, [](int64& index) { ++index; });
2192         limit_indices[sort_dim] = sort_dim_elements;
2193         std::vector<Literal> literals_to_sort;
2194         literals_to_sort.reserve(sort->operand_count());
2195         for (int64 i = 0; i < sort->operand_count(); ++i) {
2196           TF_ASSIGN_OR_RETURN(auto literal_to_sort,
2197                               GetEvaluatedLiteralFor(sort->operand(i))
2198                                   .Slice(indices, limit_indices)
2199                                   .Reshape({sort_dim_elements}));
2200           literals_to_sort.push_back(std::move(literal_to_sort));
2201         }
2202         std::vector<int64> indices_to_sort(sort_dim_elements);
2203         std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
2204         Status compare_status = Status::OK();
2205         auto comparator = [sort, &compare_status, &embedded_evaluator,
2206                            &literals_to_sort](int64 a, int64 b) {
2207           std::vector<Literal> literals;
2208           literals.reserve(2 * sort->operand_count());
2209           for (int64 i = 0; i < sort->operand_count(); ++i) {
2210             auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
2211                                                  /*extract_as_scalar=*/true);
2212             if (!lhs.ok()) {
2213               compare_status = lhs.status();
2214               return false;
2215             }
2216             literals.push_back(std::move(lhs.ValueOrDie()));
2217             auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
2218                                                  /*extract_as_scalar=*/true);
2219             if (!rhs.ok()) {
2220               compare_status = rhs.status();
2221               return false;
2222             }
2223             literals.push_back(std::move(rhs.ValueOrDie()));
2224           }
2225           std::vector<const Literal*> literal_ptrs;
2226           absl::c_transform(literals, std::back_inserter(literal_ptrs),
2227                             [](const Literal& literal) { return &literal; });
2228 
2229           auto computed_result =
2230               embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
2231           // Clear visit states so that we can use the evaluator again
2232           // on the same computation.
2233           embedded_evaluator.ResetVisitStates();
2234           if (!computed_result.ok()) {
2235             compare_status = computed_result.status();
2236             return false;
2237           }
2238           return computed_result.ValueOrDie().Get<bool>({});
2239         };
2240         if (Cast<HloSortInstruction>(sort)->is_stable()) {
2241           std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
2242                            comparator);
2243         } else {
2244           std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
2245         }
2246         if (!compare_status.ok()) {
2247           return compare_status;
2248         }
2249         std::vector<int64> slice_dimensions(rank, 1);
2250         slice_dimensions[sort_dim] = sort_dim_elements;
2251         std::vector<int64> start_indices(rank, 0);
2252         for (int64 i = 0; i < sort->operand_count(); ++i) {
2253           TF_ASSIGN_OR_RETURN(
2254               Literal sorted_literal,
2255               ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
2256           TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
2257                               sorted_literal.Reshape(slice_dimensions));
2258           TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
2259               sorted_literal_reshaped, start_indices, indices,
2260               slice_dimensions));
2261         }
2262         return true;
2263       }));
2264 
2265   if (sort->operand_count() == 1) {
2266     evaluated_[sort] = std::move(result_literals[0]);
2267   } else {
2268     std::vector<const Literal*> literal_ptrs;
2269     absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
2270                       [](const Literal& literal) { return &literal; });
2271 
2272     Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
2273     VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
2274 
2275     evaluated_[sort] = std::move(result_tuple);
2276   }
2277   return Status::OK();
2278 }
2279 
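// Returns true if `computation` adds its two distinct scalar parameters, e.g.
// (HLO text form, illustrative only):
//   add {
//     p0 = f32[] parameter(0)
//     p1 = f32[] parameter(1)
//     ROOT sum = f32[] add(p0, p1)
//   }
// Such computations allow the fast-path accumulation in
// GenerateReduceOutputElement below.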
2280 static bool IsScalarAdd(HloComputation* computation) {
2281   HloInstruction* instruction = computation->root_instruction();
2282   if (instruction->opcode() == HloOpcode::kAdd &&
2283       computation->num_parameters() == 2) {
2284     const HloInstruction* lhs = instruction->operand(0);
2285     const HloInstruction* rhs = instruction->operand(1);
2286     return lhs->opcode() == HloOpcode::kParameter &&
2287            ShapeUtil::IsScalar(lhs->shape()) &&
2288            rhs->opcode() == HloOpcode::kParameter &&
2289            ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
2290   }
2291   return false;
2292 }
2293 
2294 // Run a single step of an inner loop while running reduction, which applies
2295 // the user-provided computation on the accumulator and the output element
2296 // (until the reduction is completed, the output element is also used as
2297 // an accumulator).
2298 static StatusOr<bool> PerformReductionStep(
2299     bool is_tuple, absl::Span<const int64> input_index,
2300     absl::Span<const int64> output_index,
2301     absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
2302     HloComputation* computation, HloEvaluator* embedded_evaluator) {
2303   int num_args = results.size();
2304 
2305   absl::InlinedVector<Literal, 1> arg_values;
2306   arg_values.reserve(num_args);
2307   absl::InlinedVector<Literal, 1> accumulators;
2308   accumulators.reserve(num_args);
2309   for (int64 i = 0; i < num_args; ++i) {
2310     arg_values.emplace_back(
2311         ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
2312     accumulators.emplace_back(
2313         ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
2314 
2315     TF_RETURN_IF_ERROR(
2316         arg_values[i].CopyElementFrom(*input_args[i], input_index, {}));
2317     TF_RETURN_IF_ERROR(
2318         accumulators[i].CopyElementFrom(results[i], output_index, {}));
2319   }
2320 
2321   // Evaluate computation with specified literal operands.
2322   absl::InlinedVector<Literal*, 2> embedded_operands;
2323   for (Literal& accumulator : accumulators) {
2324     embedded_operands.push_back(&accumulator);
2325   }
2326   for (Literal& local_input : arg_values) {
2327     embedded_operands.push_back(&local_input);
2328   }
2329 
2330   TF_ASSIGN_OR_RETURN(
2331       Literal computed_result,
2332       embedded_evaluator->Evaluate(*computation, embedded_operands));
2333 
2334   // Clear visit states so that we can use the evaluator again on the same
2335   // computation.
2336   embedded_evaluator->ResetVisitStates();
2337 
2338   if (is_tuple) {
2339     std::vector<Literal> computed_results = computed_result.DecomposeTuple();
2340     for (int64 i = 0; i < num_args; ++i) {
2341       TF_RETURN_IF_ERROR(
2342           results[i].CopyElementFrom(computed_results[i], {}, output_index));
2343     }
2344   } else {
2345     TF_RETURN_IF_ERROR(
2346         results[0].CopyElementFrom(computed_result, {}, output_index));
2347   }
2348 
2349   return true;
2350 }

static StatusOr<bool> GenerateReduceOutputElement(
    bool is_tuple, absl::Span<const int64> output_index,

    absl::Span<const Literal* const> init_values,
    absl::Span<const Literal* const> input_args, absl::Span<Literal> results,

    HloComputation* function, HloEvaluator* embedded_evaluator,

    absl::Span<const int64> arg_dim_steps,
    absl::Span<const int64> arg_dim_counts,
    absl::Span<const int64> result_to_arg_index) {
  // Fast path: a non-variadic floating-point reduction whose reducer is a
  // scalar add can be accumulated directly in a double, without invoking the
  // embedded evaluator for every element.
  bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) &&
                      IsScalarAdd(function) && !is_tuple;

  const Shape& arg_shape = input_args[0]->shape();
  absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
  std::vector<int64> base(arg_dimensions.size());
  for (int64 i = 0; i < output_index.size(); ++i) {
    base[result_to_arg_index[i]] = output_index[i];
  }

  for (int64 i = 0; i < results.size(); ++i) {
    TF_RETURN_IF_ERROR(
        results[i].CopyElementFrom(*init_values[i], {}, output_index));
  }

  if (use_fast_add) {
    double computed_result = *init_values[0]->GetAsDouble({});
    auto reduction_step =
        [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
      double argument = *input_args[0]->GetAsDouble(input_index);
      computed_result += argument;
      return true;
    };
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step));
    TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result));
    return true;
  }

  // Iterates only over reduced shape, as counts and steps are set to zero
  // for all non-reduced dimensions.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      arg_shape, base, arg_dim_counts, arg_dim_steps,
      [&](absl::Span<const int64> input_index) {
        return PerformReductionStep(is_tuple, input_index, output_index,
                                    input_args, results, function,
                                    embedded_evaluator);
      }));
  return true;
}

Status HloEvaluator::HandleReduce(HloInstruction* instr) {
  HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
  int64 num_args = reduce->inputs().size();
  absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
  HloComputation* function = reduce->to_apply();

  absl::InlinedVector<const Shape*, 1> operand_shapes;
  for (const HloInstruction* operand : reduce->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                      ShapeInference::InferReduceShape(
                          operand_shapes, dimensions_to_reduce,
                          /*to_apply=*/function->ComputeProgramShape()));
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(),
                                                        inferred_return_shape))
      << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
      << " but is inferred to be: "
      << ShapeUtil::HumanString(inferred_return_shape);

  absl::InlinedVector<const Literal*, 1> input_args(num_args);
  absl::InlinedVector<const Literal*, 1> init_values(num_args);
  for (int64 i = 0; i < num_args; ++i) {
    input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]);
    VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString();
    init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]);
    VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString();
    TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape()));
  }

  // All args and results have the same dimensions, so pick an arbitrary one.
  const Shape& arg_shape = input_args[0]->shape();
  const Shape& out_shape = inferred_return_shape;
  bool is_tuple = out_shape.IsTuple();
  const Shape& output_shape = inferred_return_shape.IsTuple()
                                  ? inferred_return_shape.tuple_shapes(0)
                                  : inferred_return_shape;

  absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());

  // All increments are initialized to 0.
  std::vector<int64> arg_dim_steps(arg_dimensions.size());

  // All counts are initialized to 0.
  std::vector<int64> arg_dim_counts(arg_dimensions.size());

  // Set steps and counts for reduced dimensions.
  // This avoids iterating over non-reduced dimensions, as their steps
  // and counts remain zero.
  for (const int64 dim : dimensions_to_reduce) {
    arg_dim_steps[dim] = 1;
    arg_dim_counts[dim] = arg_dimensions[dim];
  }
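
  // As an illustrative (non-source) example: reducing an f32[2,3] input over
  // dimension {1} gives arg_dim_steps = {0, 1} and arg_dim_counts = {0, 3}.
  // For output index {i}, the base index computed in
  // GenerateReduceOutputElement is {i, 0}, so the per-element walk visits only
  // {i, 0}, {i, 1}, {i, 2}, keeping the non-reduced dimension fixed.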

  // Map each dimension in the result to a dimension in arg that isn't
  // being reduced.
  std::vector<int64> result_to_arg_index;
  for (int64 i = 0; i < arg_dimensions.size(); ++i) {
    if (arg_dim_steps[i] == 0) {
      result_to_arg_index.push_back(i);
    }
  }

  HloEvaluator embedded_evaluator(max_loop_iterations_);
  absl::InlinedVector<Literal, 1> results(num_args);
  for (int64 i = 0; i < num_args; ++i) {
    results[i] = Literal(is_tuple ? out_shape.tuple_shapes(i) : out_shape);
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      output_shape, [&](absl::Span<const int64> output_index) {
        return GenerateReduceOutputElement(
            is_tuple, output_index, init_values, input_args,
            absl::Span<Literal>(results), function, &embedded_evaluator,
            arg_dim_steps, arg_dim_counts, result_to_arg_index);
      }));

  if (is_tuple) {
    Literal tuple_result(inferred_return_shape);
    for (int64 i = 0; i < num_args; ++i) {
      TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
    }
    evaluated_[reduce] = std::move(tuple_result);
  } else {
    CHECK_EQ(results.size(), 1);
    evaluated_[reduce] = std::move(results[0]);
  }
  if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) {
    TF_ASSIGN_OR_RETURN(evaluated_[reduce],
                        evaluated_[reduce].ConvertToShape(reduce->shape()));
  }
  return Status::OK();
}

Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
  // Here we delegate the handling to the typed visitor class, instantiated
  // with the type of the first input of ReduceWindow. The support for the
  // variadic case inside the typed visitor does not use the template
  // parameter, so it does not matter which type is used to instantiate it
  // here. We choose not to move the implementation of HandleReduceWindow
  // from the typed visitor to here because we need to reuse the
  // IterateThroughWindow method, which is defined and only available inside
  // the typed visitor.
  if (hlo->shape().IsTuple()) {
    return hlo->Visit(
        typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get());
  } else {
    return DefaultAction(hlo);
  }
}

Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
  if (!custom_call_handler_) {
    // No handler is registered; this means custom-calls are not allowed.
    return DefaultAction(custom_call);
  }

  // Evaluate input operands so the handler has access to the operand data.
  std::vector<const Literal*> operands;
  operands.reserve(custom_call->operand_count());
  for (const HloInstruction* operand : custom_call->operands()) {
    operands.push_back(&GetEvaluatedLiteralFor(operand));
  }

  // Synchronously invoke the handler to populate the instruction's output
  // literal.
  TF_ASSIGN_OR_RETURN(
      auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));

  evaluated_[custom_call] = std::move(output);
  return Status::OK();
}
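
// A minimal sketch of a custom-call handler matching the call above
// (hypothetical example, not part of the original source; it assumes the
// handler is installed through a setter such as set_custom_call_handler on
// the evaluator):
//
//   HloEvaluator evaluator;
//   evaluator.set_custom_call_handler(
//       [](HloInstruction* custom_call,
//          absl::Span<const Literal*> operands) -> StatusOr<Literal> {
//         // E.g. a "my_identity" custom-call that forwards its first operand.
//         if (custom_call->custom_call_target() == "my_identity") {
//           return operands[0]->Clone();
//         }
//         return Unimplemented("Unhandled custom-call target");
//       });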

Status HloEvaluator::Preprocess(HloInstruction* hlo) {
  VLOG(2) << "About to visit HLO: " << hlo->ToString();
  return ShapeUtil::ValidateShape(hlo->shape());
}

Status HloEvaluator::Postprocess(HloInstruction* hlo) {
  VLOG(2) << "Finished visiting " << hlo->ToString()
          << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
  // For convenience, the literal may have been produced with a different
  // layout; relayout it as indicated by the HLO instruction.
  if (!Layout::Equal().MinorToMajorOnly()(
          GetEvaluatedLiteralFor(hlo).shape().layout(),
          hlo->shape().layout())) {
    evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
  }
  return Status::OK();
}

namespace {
template <typename T>
std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
    const Array2D<T>& lhs, const Array2D<T>& rhs,
    const std::function<void(
        const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n,
        int64 k, int32 transpose_lhs, int32 transpose_rhs)>& impl_fn) {
  CHECK_EQ(lhs.width(), rhs.height());
  int m = lhs.height();
  int n = rhs.width();
  int k = lhs.width();
  auto result = absl::make_unique<Array2D<T>>(m, n);
  // Because Eigen is a header-oriented library, make sure that the Eigen code
  // is the same as the code used by the CPU backend (otherwise the linker will
  // randomly pick *some* definition).
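  // Note (explanatory comment added here, not from the original source): the
  // arguments below appear swapped -- rhs before lhs and n before m --
  // presumably because the Eigen kernel computes a column-major product;
  // reinterpreting the row-major output buffer as column-major transposes it,
  // which is compensated for by multiplying the operands in reverse order.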
  impl_fn(
      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
      k,
      /*transpose_lhs=*/0,
      /*transpose_rhs=*/0);
  return result;
}
}  // namespace

std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
    const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
  return MatmulArray2DImpl<Eigen::half>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
}

std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
    const Array2D<float>& lhs, const Array2D<float>& rhs) {
  return MatmulArray2DImpl<float>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
}

std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
    const Array2D<double>& lhs, const Array2D<double>& rhs) {
  return MatmulArray2DImpl<double>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
}

std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
    const Array2D<std::complex<float>>& lhs,
    const Array2D<std::complex<float>>& rhs) {
  return MatmulArray2DImpl<std::complex<float>>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
}

std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
    const Array2D<std::complex<double>>& lhs,
    const Array2D<std::complex<double>>& rhs) {
  return MatmulArray2DImpl<std::complex<double>>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
}

std::unique_ptr<Array2D<int32>> HloEvaluator::MatmulArray2D(
    const Array2D<int32>& lhs, const Array2D<int32>& rhs) {
  return MatmulArray2DImpl<int32>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulS32);
}
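
// Minimal usage sketch (hypothetical example, not part of the original
// source), multiplying a 2x3 by a 3x2 float matrix:
//
//   Array2D<float> lhs(2, 3, /*value=*/1.0f);
//   Array2D<float> rhs(3, 2, /*value=*/2.0f);
//   HloEvaluator evaluator;
//   std::unique_ptr<Array2D<float>> product =
//       evaluator.MatmulArray2D(lhs, rhs);
//   // Every element of the 2x2 result is 1*2 + 1*2 + 1*2 = 6.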

}  // namespace xla