/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
                          LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
  std::function<bool(OperandT, OperandT)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    case ComparisonDirection::kGe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el >= rhs_el;
      };
      break;
    case ComparisonDirection::kGt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el > rhs_el;
      };
      break;
    case ComparisonDirection::kLe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el <= rhs_el;
      };
      break;
    case ComparisonDirection::kLt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el < rhs_el;
      };
      break;
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<OperandT>(multi_index),
                          rhs_literal.Get<OperandT>(multi_index));
      }));

  return std::move(result);
}

template <>
StatusOr<Literal> Compare<complex64>(const Shape& shape,
                                     ComparisonDirection direction,
                                     LiteralSlice lhs_literal,
                                     LiteralSlice rhs_literal) {
  std::function<bool(complex64, complex64)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled direction for conversion to Comparison: "
                 << ComparisonDirectionToString(direction);
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<complex64>(multi_index),
                          rhs_literal.Get<complex64>(multi_index));
      }));

  return std::move(result);
}

template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape,
                                      ComparisonDirection direction,
                                      LiteralSlice lhs_literal,
                                      LiteralSlice rhs_literal) {
  std::function<bool(complex128, complex128)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](complex128 lhs_el, complex128 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](complex128 lhs_el, complex128 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled direction for conversion to Comparison: "
                 << ComparisonDirectionToString(direction);
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<complex128>(multi_index),
                          rhs_literal.Get<complex128>(multi_index));
      }));

  return std::move(result);
}

}  // namespace

// Note that a type being unsupported by the typed visitor does not necessarily
// imply that the non-typed HloEvaluator (parent evaluator) cannot support it in
// its type-agnostic handlers. For example, HandleGetTupleElement in the parent
// type-agnostic evaluator can accept the Tuple primitive type, whereas
// HloEvaluatorTypedVisitor cannot.
HloEvaluator::HloEvaluator(int64_t max_loop_iterations)
    : max_loop_iterations_(max_loop_iterations) {
  typed_visitors_[PRED] =
      absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
  typed_visitors_[U8] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
  typed_visitors_[U16] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
  typed_visitors_[U32] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
  typed_visitors_[U64] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
  typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
  typed_visitors_[S16] =
      absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
  typed_visitors_[S32] =
      absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
  typed_visitors_[S64] =
      absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
  typed_visitors_[F16] =
      absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
  typed_visitors_[F32] =
      absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
  typed_visitors_[F64] =
      absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
  typed_visitors_[C64] =
      absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
  typed_visitors_[C128] =
      absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);

  // Most of the evaluator computations we use don't support BF16 (e.g.,
  // std::ceil, std::tanh). To make the evaluator work with BF16, we set all
  // elementwise computations to be done in F32 and do BF16<->F32 conversion
  // around the input and the output of the computations.
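  // For example (an illustrative sketch, not a verbatim excerpt): tanh on a
  // bfloat16 element x is effectively computed as
  // static_cast<bfloat16>(std::tanh(static_cast<float>(x))).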
  typed_visitors_[BF16] =
      absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);

  typed_visitors_[TUPLE] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
      });
  typed_visitors_[OPAQUE_TYPE] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE.");
      });
  typed_visitors_[TOKEN] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
      });
}

StatusOr<Literal> HloEvaluator::Evaluate(
    const HloComputation& computation,
    absl::Span<const Literal* const> arg_literals) {
  CHECK(computation.parent() != nullptr);
  XLA_VLOG_LINES(
      2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());

  if (arg_literals.size() != computation.num_parameters()) {
    return InvalidArgument(
        "Expected %d argument%s, but got %d.", computation.num_parameters(),
        computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
  }
  for (int64_t i = 0; i < arg_literals.size(); ++i) {
    const auto& computation_shape =
        computation.parameter_instruction(i)->shape();
    const auto& arg_shape = arg_literals[i]->shape();
    if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
                                                   arg_shape)) {
      return InvalidArgument(
          "Shape mismatch at parameter %d. Computation expected %s, but arg "
          "was %s.",
          i, ShapeUtil::HumanStringWithLayout(computation_shape),
          ShapeUtil::HumanStringWithLayout(arg_shape));
    }
  }

  evaluated_.clear();
  arg_literals_.clear();
  for (const auto& literal_ptr : arg_literals) {
    arg_literals_.push_back(&*literal_ptr);
  }

  // Re-seed RNG, either from the configuration's seed or a monotonic
  // per-evaluator seed (which prevents two evaluators from returning the same
  // random sequence).
  if (computation.parent()->config().seed()) {
    seed_ = computation.parent()->config().seed();
  } else {
    // Start global_seed at a (true) random value.
    static std::atomic<uint64> global_seed{std::random_device()()};
    seed_ = global_seed.fetch_add(1);
  }
  engine_.seed(seed_);

  TF_RETURN_IF_ERROR(computation.Accept(this));

  if (VLOG_IS_ON(100)) {
    for (const HloInstruction* instr : computation.instructions()) {
      VLOG(100) << instr->name() << " = " << GetEvaluatedLiteralFor(instr);
    }
  }

  return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
}

StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
  // If the instruction is a kCopyDone, simply find the argument that it is
  // copied from.
  while (instruction->opcode() == HloOpcode::kCopyDone) {
    if (instruction->operand(0)->opcode() != HloOpcode::kCopyStart) {
      return tensorflow::errors::FailedPrecondition(
          "kCopyDone has an argument different than a kCopyStart.");
    }
    instruction = instruction->mutable_operand(0)->mutable_operand(0);
  }
  if (instruction->opcode() == HloOpcode::kParameter) {
    return tensorflow::errors::FailedPrecondition(
        "Cannot evaluate a parameter.");
  }
  if (!hlo_query::AllOperandsAreConstants(*instruction)) {
    return tensorflow::errors::FailedPrecondition(
        "Not all operands are constants.");
  }

  arg_literals_.clear();
  evaluated_.clear();

  TF_RETURN_IF_ERROR(Preprocess(instruction));
  TF_RETURN_IF_ERROR(instruction->Visit(this));
  TF_RETURN_IF_ERROR(Postprocess(instruction));
  return GetEvaluatedLiteralFor(instruction).Clone();
}

bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
  CHECK(result != nullptr);
  auto result_or = Evaluate(instruction);
  if (!result_or.ok()) {
    VLOG(1) << "TryEvaluate failed:" << result_or.status();
    return false;
  }

  *result = result_or.ConsumeValueOrDie();
  return true;
}

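// Illustrative usage sketch (the names here are hypothetical, not from this
// file): for an instruction add = Add(p0, c), calling
// EvaluateWithSubstitutions(add, {{p0, &literal}}) evaluates the add with p0
// replaced by a constant holding 'literal'.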
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
    const HloInstruction* instruction,
    const std::unordered_map<const HloInstruction*, const Literal*>&
        substitutions) {
  std::vector<std::unique_ptr<HloInstruction>> owned_operands;
  for (const HloInstruction* operand : instruction->operands()) {
    auto it = substitutions.find(operand);
    if (it == substitutions.end()) {
      owned_operands.push_back(operand->Clone());
    } else {
      owned_operands.push_back(
          HloInstruction::CreateConstant(it->second->Clone()));
    }
  }

  std::vector<HloInstruction*> operands;
  operands.reserve(owned_operands.size());
  for (auto& operand : owned_operands) {
    operands.push_back(operand.get());
  }

  std::unique_ptr<HloInstruction> cloned_instruction =
      instruction->CloneWithNewOperands(instruction->shape(), operands);
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
    HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
                                   rhs_instr.get());
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
    HloOpcode opcode, const Literal& lhs, const Literal& rhs,
    const Literal& ehs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());
  std::unique_ptr<HloInstruction> ehs_instr =
      HloInstruction::CreateConstant(ehs.Clone());
  TF_ASSIGN_OR_RETURN(auto output_shape,
                      ShapeInference::InferTernaryOpShape(
                          opcode, lhs.shape(), rhs.shape(), ehs.shape()));
  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateTernary(output_shape, opcode, lhs_instr.get(),
                                    rhs_instr.get(), ehs_instr.get());
  return Evaluate(cloned_instruction.get());
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
    ComparisonDirection direction, const Literal& lhs, const Literal& rhs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(lhs.shape(), PRED), lhs_instr.get(),
          rhs_instr.get(), direction);
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
    HloOpcode opcode, const Literal& operand) {
  std::unique_ptr<HloInstruction> operand_instr =
      HloInstruction::CreateConstant(operand.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateDotOp(
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, const Literal& lhs,
    const Literal& rhs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());

  TF_ASSIGN_OR_RETURN(Shape dot_shape,
                      ShapeInference::InferDotOpShape(
                          lhs.shape(), rhs.shape(), dim_numbers,
                          /*preferred_element_type=*/absl::nullopt));

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
                                dim_numbers, precision_config);
  return Evaluate(cloned_instruction.get());
}

Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
  const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
  Literal result(bitcast->shape());
  TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
  memcpy(result.untyped_data(), operand_literal.untyped_data(),
         operand_literal.size_bytes());
  evaluated_[bitcast] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleGetDimensionSize(
    HloInstruction* get_dimension_size) {
  HloInstruction* operand = get_dimension_size->mutable_operand(0);
  int64_t dim = get_dimension_size->dimension();
  if (dynamic_dimension_inference_ == nullptr) {
    return InvalidArgument(
        "Evaluator cannot evaluate get_dimension_size without "
        "set_dynamic_dimension_inference.");
  }
  HloInstruction* dynamic_size =
      dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
  if (dynamic_size != nullptr) {
    evaluated_[get_dimension_size] =
        GetEvaluatedLiteralFor(dynamic_size).Clone();
    return Status::OK();
  }

  const Shape& shape = get_dimension_size->operand(0)->shape();
  Literal output(ShapeUtil::MakeShape(S32, {}));
  output.PopulateWithValue(
      static_cast<int32>(shape.dimensions(get_dimension_size->dimension())));
  evaluated_[get_dimension_size] = std::move(output);
  return Status::OK();
}

Status HloEvaluator::HandleSetDimensionSize(
    HloInstruction* set_dimension_size) {
  const Literal& operand_literal =
      GetEvaluatedLiteralFor(set_dimension_size->operand(0));
  Literal result(set_dimension_size->shape());
  memcpy(result.untyped_data(), operand_literal.untyped_data(),
         operand_literal.size_bytes());
  const Literal& size_literal =
      GetEvaluatedLiteralFor(set_dimension_size->operand(1));
  result.SetDynamicSize(set_dimension_size->dimension(),
                        size_literal.Get<int32>({}));
  evaluated_[set_dimension_size] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
  // Nothing to do other than sanity checks. Parameters' values are stored in
  // arg_literals_.
  CHECK_LT(parameter->parameter_number(), arg_literals_.size());

#ifndef NDEBUG
  const Literal* input_literal = arg_literals_[parameter->parameter_number()];
  VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
  DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
                                                   input_literal->shape()))
      << "parameter shape is: "
      << ShapeUtil::HumanStringWithLayout(parameter->shape())
      << ", but input literal shape is: "
      << ShapeUtil::HumanStringWithLayout(input_literal->shape());
#endif

  return Status::OK();
}

Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }

Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
  TF_ASSIGN_OR_RETURN(
      evaluated_[reshape],
      GetEvaluatedLiteralFor(reshape->operand(0))
          .Reshape(AsInt64Slice(reshape->shape().dimensions())));
  return Status::OK();
}

Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
  evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
                              .Transpose(transpose->dimensions());
  return Status::OK();
}

Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  // The result concatenate dimension is going to be the sum of all
  // concatenate dimensions of the operands taking part in the operation.
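  // For example (illustrative), concatenating f32[2,3] and f32[2,5] along
  // dimension 1 yields an f32[2,8] result.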
  const Shape& reference_shape = operands[0]->shape();
  CHECK(reference_shape.IsArray());
  const int64_t rank = reference_shape.rank();
  const int64_t concat_dim = concatenate->dimensions()[0];
  CHECK_GE(concat_dim, 0);
  CHECK_LT(concat_dim, rank);

  DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
                                    reference_shape.dimensions().end());

  for (int64_t i = 1; i < operands.size(); ++i) {
    const Shape& operand_shape = operands[i]->shape();
    CHECK(operand_shape.IsArray());
    // Accumulate the concat dimension from all tensors taking part in the
    // operation.
    concat_dimensions[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  auto result_literal = LiteralUtil::CreateFromDimensions(
      reference_shape.element_type(), concat_dimensions);
  DimensionVector source_indices(rank, 0);
  DimensionVector dest_indices(concat_dimensions.size(), 0);

  for (auto operand : operands) {
    const Shape& operand_shape = operand->shape();
    TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
        GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
        AsInt64Slice(operand_shape.dimensions())));
    dest_indices[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  evaluated_[concatenate] = std::move(result_literal);
  return Status::OK();
}

Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
  auto operand = is_finite->operand(0);
  auto elem_ty = operand->shape().element_type();
  switch (elem_ty) {
    case PRED:
    case TUPLE:
    case OPAQUE_TYPE:
    case TOKEN:
    case S8:
    case S16:
    case S32:
    case S64:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    // Explicitly enumerate all types in this switch so that when we add a new
    // type, we'll get a compile error here.
    case PRIMITIVE_TYPE_INVALID:
    case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
    case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
      return InvalidArgument(
          "expected element type in shape to be floating point, but "
          "got: %s",
          PrimitiveType_Name(elem_ty));

    case F16: {
      auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
          is_finite,
          [](Eigen::half elem_operand) {
            return std::isfinite(static_cast<float>(elem_operand));
          },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case BF16: {
      auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
          is_finite,
          [](bfloat16 elem_operand) {
            return std::isfinite(static_cast<float>(elem_operand));
          },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case F32: {
      auto result_or = ElementWiseUnaryOpImpl<bool, float>(
          is_finite,
          [](float elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case F64: {
      auto result_or = ElementWiseUnaryOpImpl<bool, double>(
          is_finite,
          [](double elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
  }

  return Status::OK();
}

Status HloEvaluator::HandleReal(HloInstruction* real) {
  auto operand = real->operand(0);
  switch (operand->shape().element_type()) {
    case BF16: {
      auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
          real, [](bfloat16 elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case C64: {
      auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
          real, [](complex64 elem_operand) { return std::real(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case C128: {
      auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
          real, [](complex128 elem_operand) { return std::real(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F16: {
      auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
          real, [](Eigen::half elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F32: {
      auto result_or = ElementWiseUnaryOpImpl<float, float>(
          real, [](float elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F64: {
      auto result_or = ElementWiseUnaryOpImpl<double, double>(
          real, [](double elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    default:
      LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(operand->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleImag(HloInstruction* imag) {
  auto operand = imag->operand(0);
  switch (operand->shape().element_type()) {
    case C64: {
      auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
          imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
          GetEvaluatedLiteralFor(imag->operand(0)));

      TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
      break;
    }
    case C128: {
      auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
          imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
          GetEvaluatedLiteralFor(imag->operand(0)));

      TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
      break;
    }
    default:
      LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(operand->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleComplex(HloInstruction* complex) {
  const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
  const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
  TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));

  Literal result(complex->shape());
  switch (complex->shape().element_type()) {
    case C64: {
      TF_RETURN_IF_ERROR(
          result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
            return std::complex<float>(real.Get<float>(multi_index),
                                       imag.Get<float>(multi_index));
          }));
      break;
    }
    case C128: {
      TF_RETURN_IF_ERROR(
          result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
            return std::complex<double>(real.Get<double>(multi_index),
                                        imag.Get<double>(multi_index));
          }));
      break;
    }
    default:
      LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(complex->shape().element_type());
  }

  evaluated_[complex] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleCompare(HloInstruction* compare) {
  ComparisonDirection direction = compare->comparison_direction();
  auto lhs = compare->operand(0);
  auto rhs = compare->operand(1);
  DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
         ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));

  TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());

  const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
  const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);

  // Note here we switch on the operand's type.
  switch (lhs->shape().element_type()) {
    case PRED: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case U8: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint8>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case U16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint16>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case U32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint32>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case U64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint64>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case S8: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case S16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int16>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case S32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int32>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case S64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int64>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case F16: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case BF16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<bfloat16>(compare->shape(), direction,
                                            lhs_literal, rhs_literal));
    } break;
    case F32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<float>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case F64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<double>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case C64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<complex64>(compare->shape(), direction,
                                             lhs_literal, rhs_literal));
    } break;
    case C128: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<complex128>(compare->shape(), direction,
                                              lhs_literal, rhs_literal));
    } break;
    default:
      LOG(FATAL) << "HandleCompare: unknown primitive type: "
                 << PrimitiveType_Name(lhs->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
  std::vector<const Literal*> operand_literals;
  for (auto operand : tuple->operands()) {
    operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
  }

  evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
  return Status::OK();
}

namespace {

// These helper templates convert the data type and are intended to be used only
// within the DFT implementation below. The special case is IRFFT, where the
// specialization drops imaginary parts of complex values and returns real
// numbers.
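// For instance (illustrative): TypeConverter<float, complex64>::GetAs(
// complex64(2.0f, 3.0f)) returns 2.0f via the specialization below, while the
// primary template simply casts, e.g. TypeConverter<complex128, float>::GetAs(
// 2.0f) yields complex128(2.0, 0.0).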
template <typename ToType, typename FromType>
struct TypeConverter {
  static inline ToType GetAs(FromType value) {
    return static_cast<ToType>(value);
  }
};

template <typename FromType>
struct TypeConverter<float, FromType> {
  static inline float GetAs(FromType value) {
    return static_cast<float>(value.real());
  }
};

// This class implements the discrete Fourier transform. All transform types
// (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the arbitrary rank and
// length of each dimension of the transform, and arbitrary layouts of the input
// and output literals. The class template parameter must be a complex type, and
// all internal calculations will be performed using this type.
//
// The input literal provides input data, which must be complex64 for FFT, IFFT,
// IRFFT transforms and float for RFFT. The transform is computed over the
// innermost dimensions of the input, thus the rank of the input data must be
// the same as fft_rank or larger. The input is expected to provide Ni values
// along each transform axis with one exception: for IRFFT, only (N0 / 2) + 1
// values are needed along the X axis (the innermost index). To increase
// flexibility, this implementation can handle mismatches between the input size
// and transform lengths by either dropping extra input values or using zeroes
// in place of missing input values as necessary. If the input data has rank
// higher than the transform, the transform is applied for each valid
// combination of the higher-ranking indices.
//
// The output contains complex64 values for FFT, IFFT, RFFT, and float values
// for IRFFT. The rank of the output as well as the sizes of the dimensions
// above the rank of the transform must match those of the input. Sizes of the
// output's "fft_rank" innermost dimensions are expected to match the length of
// the transform along respective axes with one exception: for RFFT, the output
// is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the
// length(s) mismatch, the FFT output is trimmed to fit into the provided output
// shape, or the output is padded with zero values appropriately.
//
// For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17]
// input array will perform two transforms over the [][15][17] data in the sub
// arrays [0][][] and [1][][], dropping the values along axis X and padding axis
// Y with zeroes to create 16x16 working sets, and generating
// complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to
// complex64[64][16][9] input array will use all input values and will produce
// float[64][16][16] output.
//
// The implementation of the 1D transform for lengths that are powers of 2 is
// the Cooley-Tukey radix-2 decimation-in-time. For all other 1D transform
// lengths, a straightforward, but slow, loop nest is used. The transforms of
// higher ranks apply sets of 1D transforms along each axis. For example, the 2D
// transform is computed by applying 1D transforms to each column followed by
// applying 1D transforms to each row.
//
// In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
// time, where Ni is the length of the transform's i-th dimension. However, for
// dimension lengths that are powers of 2, the run time along these dimensions
// is reduced to log(Ni) in the summation, giving the runtime of
// O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn))) in the best case.
//
template <typename ComplexType>
class FftTransform {
 public:
  explicit FftTransform(HloInstruction* fft)
      : fft_type_(fft->fft_type()),
        fft_rank_(fft->fft_length().size()),
        fft_lengths_(fft->fft_length()) {
    // Make fft_lengths_[0] the minormost dimension.
    absl::c_reverse(fft_lengths_);
  }

  Status ComputeFft(HloInstruction* fft, const Literal& input_literal,
                    Literal* output_literal) {
    const Shape& input_shape = input_literal.shape();
    const Shape& output_shape = fft->shape();

    TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape));

    const auto fft_strides = ComputeStrides(fft_lengths_);

    // Working set size.
    const int64_t fft_size = fft_strides[fft_rank_];

    if (fft_size > 0) {
      // Linearized working data set.
      std::vector<ComplexType> data(fft_size);

      // Temporary buffer allocated once and used in 1D sweeps. For dimension
      // length values that are powers of 2, the buffer should be twice as large
      // to simultaneously hold input and output in Fft1D() below.
      int64_t buffer_size = 0;
      for (auto len : fft_lengths_) {
        int64_t size = IsPowerOfTwo(static_cast<uint64>(len)) ? len * 2 : len;
        buffer_size = std::max(buffer_size, size);
      }
      std::vector<ComplexType> buffer(buffer_size);

      // Sizes of each axis of input and output literals.
      const auto input_lengths = GetDimensionLengths(input_literal);
      const auto output_lengths = GetDimensionLengths(*output_literal);

      // Strides for generating linearized indices into multidimensional arrays.
      const auto input_strides = ComputeStrides(input_lengths, input_literal);
      const auto output_strides =
          ComputeStrides(output_lengths, *output_literal);

      // Visit all elements in the dimensions with ranks above the FFT rank. For
      // each such element invoke the transform. Use separate indices for the
      // input and the output to allow different layouts.
      auto base_case = [&](int64_t axis, int64_t output_index,
                           int64_t input_index, bool within_src_bounds) {
        if (axis == fft_rank_ - 1) {
          // Base case: copy the data from the input literal, apply the
          // transform, and copy the result to the output literal.
          CHECK(within_src_bounds);
          bool input_is_zero = CopyDataFromInput(
              input_literal, input_index, fft_size, fft_lengths_, fft_strides,
              input_lengths, input_strides, absl::MakeSpan(data));
          if (!input_is_zero) {
            // Make 1D sweeps along each transform axis.
            Sweep(fft_lengths_, fft_strides, absl::MakeSpan(data),
                  absl::MakeSpan(buffer));
          }
          CopyDataToOutput(absl::MakeSpan(data), output_index, fft_lengths_,
                           fft_strides, output_lengths, output_strides,
                           output_literal);
          return true;
        }
        return false;
      };
      GenerateIndices(output_lengths, output_strides, input_lengths,
                      input_strides, input_shape.rank(), 0, 0, base_case);
    }

    return Status::OK();
  }

 private:
  // Common code used by 1D implementations, which copies data from the input to
  // the contiguous buffer. Returns true if all copied values are zero.
  static bool GatherToBuffer(absl::Span<ComplexType> data, int64_t length,
                             int64_t start, int64_t stride, bool expand_input,
                             absl::Span<ComplexType> buffer) {
    CHECK_GE(buffer.size(), length);
    bool input_is_zero = true;
    const int64_t ub = expand_input ? length / 2 + 1 : length;
    CHECK_GE(data.size(), start + (ub - 1) * stride);
    for (int64_t k = 0; k < ub; k++) {
      ComplexType value = data[start + k * stride];
      input_is_zero &= value == ComplexType(0.0, 0.0);
      buffer[k] = value;
      if (expand_input) {
        // Use conjugates of the values at indices [1 ... (ub - 2)] when the
        // length is even and at indices [1 ... (ub - 1)] when the length is odd
        // to calculate missing values at indices [(length - 1) ... ub].
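        // For example (illustrative): with length = 8 and expand_input true,
        // ub = 5 and buffer[7], buffer[6], buffer[5] receive the conjugates of
        // buffer[1], buffer[2], buffer[3], respectively.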
        if (k > 0 && k < (length - ub + 1)) {
          buffer[length - k] = std::conj(value);
        }
      }
    }
    return input_is_zero;
  }

  // Returns (conjugated, if 'inverse' is true) k-th twiddle for the given
  // length.
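  // Concretely, the value is exp(-2*pi*i*k/length); e.g. (illustrative)
  // Twiddle(1, 4, /*inverse=*/false) is approximately (0, -1).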
  static inline ComplexType Twiddle(int64_t k, int64_t length, bool inverse) {
    auto coeff = std::exp(ComplexType(0.0, -2.0 * M_PI * k / length));
    return inverse ? std::conj(coeff) : coeff;
  }

  // Straightforward implementation of 1D DFT transform of arbitrary length.
  // Uses passed-in start index and stride to gather inputs from the data vector
  // into the preallocated buffer, computes the result, and writes it back to
  // the same locations in the data vector. Runs in O(length^2) time.
  //
  // Parameters contract_output and expand_input are used to avoid unnecessary
  // calculations. When contract_output is set to true, then only (length / 2) +
  // 1 output values are computed. When expand_input is set to true, then
  // (length / 2) + 1 values from the data set are used to re-create the full
  // set of size 'length', on which the transform is then performed.
  //
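  // As a reminder of the underlying math (standard DFT convention): the forward
  // transform computes X[k] = sum_{n=0}^{length-1} x[n] * exp(-2*pi*i*n*k /
  // length); the inverse uses conjugated twiddles and additionally divides by
  // 'length' (the division by ComplexType(length, 0.0) below).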
  static void NaiveDft1D(int64_t length, int64_t start, int64_t stride,
                         bool inverse, bool contract_output, bool expand_input,
                         absl::Span<ComplexType> data,
                         absl::Span<ComplexType> buffer) {
    const bool input_is_zero =
        GatherToBuffer(data, length, start, stride, expand_input, buffer);

    if (!input_is_zero) {
      const int64_t ub = contract_output ? length / 2 + 1 : length;
      for (int64_t k = 0; k < ub; k++) {
        ComplexType value = ComplexType(0.0, 0.0);
        for (int n = 0; n < length; n++) {
          value += buffer[n] * Twiddle(n * k, length, inverse);
        }
        data[start + k * stride] =
            inverse ? value / ComplexType(length, 0.0) : value;
      }
    }
  }

  // Non-recursive implementation of the Cooley-Tukey radix-2 decimation in
  // time. Performs the 1D FFT transform for lengths that are powers of 2.
  // Runs in O(length * log(length)) time. Uses the same parameters as the naive
  // implementation above, except that the preallocated buffer must be at least
  // twice as big as the length of the transform, because the buffer is used to
  // hold both input and output values for each stage of the transform.
  //
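  // For example (illustrative): for length = 8, the stages below use
  // num_blocks = 1, 2, 4 with block sizes 8, 4, 2 respectively, swapping the
  // input and output halves of the buffer between stages.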
  static void Fft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
                    bool contract_output, bool expand_input,
                    absl::Span<ComplexType> data,
                    absl::Span<ComplexType> buffer) {
    CHECK(IsPowerOfTwo(static_cast<uint64>(length)));
    const bool input_is_zero =
        GatherToBuffer(data, length, start, stride, expand_input, buffer);

    if (!input_is_zero) {
      auto generate_twiddles = [](int64_t length, bool inverse) {
        std::vector<ComplexType> twiddles;
        // Need only half the twiddles.
        for (int64_t k = 0; k < length / 2; k++) {
          twiddles.push_back(Twiddle(k, length, inverse));
        }
        return twiddles;
      };

      // Indices into the parts of the buffer used for input and output values.
      int64_t in_base = length;
      int64_t out_base = 0;

      // At each stage, we "split" the input data into num_blocks, with
      // block_size values in each block.
      for (int64_t num_blocks = 1; num_blocks < length; num_blocks *= 2) {
        // Swap input and output parts of the buffer.
        std::swap(in_base, out_base);
        auto twiddles = generate_twiddles(num_blocks * 2, inverse);
        const int64_t block_size = length / num_blocks;
        const int64_t next_iteration_block_size = block_size / 2;
        for (int64_t block = 0; block < num_blocks; block++) {
          const int64_t in_offset = in_base + block * block_size;
          const int64_t out_offset =
              out_base + block * next_iteration_block_size;
          // For each (even, odd) pair of values in the block, calculate two
          // output values as even + twiddle * odd and even - twiddle * odd.
          for (int64_t pair = 0; pair < block_size / 2; pair++) {
            const ComplexType even = buffer[in_offset + pair];
            const ComplexType odd = buffer[in_offset + block_size / 2 + pair];
            const ComplexType twiddled_odd = twiddles[block] * odd;
            buffer[out_offset + pair] = even + twiddled_odd;
            buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
          }
        }
      }
      // Copy computed result back to data.
      const int64_t ub = contract_output ? length / 2 + 1 : length;
      for (int64_t k = 0; k < ub; k++) {
        ComplexType value = buffer[out_base + k];
        data[start + k * stride] =
            inverse ? value / ComplexType(length, 0.0) : value;
      }
    }
  }

  // Determine which implementation of the 1D transform to use and call it.
  static void Dft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
                    bool contract_output, bool expand_input,
                    absl::Span<ComplexType> data,
                    absl::Span<ComplexType> buffer) {
    if (IsPowerOfTwo(static_cast<uint64>(length))) {
      Fft1D(length, start, stride, inverse, contract_output, expand_input, data,
            buffer);
    } else {
      NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
                 data, buffer);
    }
  }

  // Helper to reverse the order of dimension lengths in the passed-in literal.
  static std::vector<int64> GetDimensionLengths(const Literal& literal) {
    auto dimensions = literal.shape().dimensions();
    return std::vector<int64>(dimensions.rbegin(), dimensions.rend());
  }

  // Helper to compute strides for creating linear indices into multidimensional
  // data from the dimension lengths and the layout. Returns a new vector of
  // size lengths.size() + 1. The last element of the returned vector at index
  // [lengths.size()] contains the product of all dimension lengths.
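  // For example (illustrative): lengths = {2, 3, 4} with the default layout
  // produces strides = {1, 2, 6, 24}.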
  static std::vector<int64> ComputeStrides(
      const absl::Span<const int64> lengths, const Layout& layout) {
    const int64_t num_dimensions = lengths.size();

    // Make sure that the layout length matches the number of dimensions.
    CHECK_EQ(num_dimensions, layout.minor_to_major_size());

    // Calculate strides using layout-specified ordering of the dimensions and
    // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
    std::vector<int64> strides(num_dimensions + 1);
    int64_t stride = 1;
    for (int64_t i = 0; i < num_dimensions; i++) {
      // Reverse the ordering of the dimensions in the layout.
      const int64_t index = (num_dimensions - 1) - layout.minor_to_major(i);
      strides[index] = stride;
      stride *= lengths[index];
    }
    strides[num_dimensions] = stride;

    return strides;
  }

  // Compute strides as above using the default layout.
  static std::vector<int64> ComputeStrides(
      const absl::Span<const int64> lengths) {
    return ComputeStrides(lengths,
                          LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
  }

  // Compute strides as above using the layout from the literal, if available.
  static std::vector<int64> ComputeStrides(
      const absl::Span<const int64> lengths, const Literal& literal) {
    return literal.shape().has_layout()
               ? ComputeStrides(lengths, literal.shape().layout())
               : ComputeStrides(lengths);
  }

  // Make 1D sweeps along each transform axis.
  void Sweep(const absl::Span<const int64> fft_lengths,
             const absl::Span<const int64> fft_strides,
             absl::Span<ComplexType> data, absl::Span<ComplexType> buffer) {
    const bool inverse =
        fft_type_ == FftType::IFFT || fft_type_ == FftType::IRFFT;
    const bool input_is_truncated = fft_type_ == FftType::IRFFT;
    const bool output_is_truncated = fft_type_ == FftType::RFFT;

    // Recursively visit each column of the data along the sweep_axis. Calculate
    // linearized index of that column's first element and the stride, then
    // invoke 1D transform. For RFFT, avoid calculating unused output values:
    // first, compute only (length_x / 2) + 1 values along the X axis, then
    // limit the X coordinate to [0 ... (length / 2)] during the sweeps along
    // other axes. Similarly, for IRFFT sweep along higher dimensions first,
    // while keeping the X coordinate in the [0 ... (length / 2)] range, then
    // re-create negative frequencies omitted in the input and perform the
    // full-length transform along the X axis in the last sweep.
    std::function<void(int64_t, int64_t, int64_t)> sweep =
        [&](int64_t sweep_axis, int64_t axis, int64_t start) {
          if (axis < 0) {
            // Base case: invoke 1D transform.
            const int64_t length = fft_lengths[sweep_axis];
            const int64_t stride = fft_strides[sweep_axis];
            const bool expand_input = input_is_truncated && sweep_axis == 0;
1208             const bool contract_output =
1209                 output_is_truncated && sweep_axis == 0;
1210             Dft1D(length, start, stride, inverse, contract_output,
1211                   expand_input, data, buffer);
1212           } else if (axis == sweep_axis) {
1213             // Visit only the elements with coordinate 0 along the sweep axis.
1214             sweep(sweep_axis, axis - 1, start);
1215           } else {
1216             const int64_t length = fft_lengths[axis];
1217             const bool is_truncated = input_is_truncated || output_is_truncated;
1218             const int64_t ub =
1219                 is_truncated && axis == 0 ? (length / 2) + 1 : length;
1220             for (int64_t i = 0; i < ub; i++) {
1221               sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
1222             }
1223           }
1224         };
1225     if (input_is_truncated) {
1226       // Sweep along the X axis last for IRFFT.
1227       for (int64_t sweep_axis = fft_rank_ - 1; sweep_axis >= 0; sweep_axis--) {
1228         sweep(sweep_axis, fft_rank_ - 1, 0);
1229       }
1230     } else {
1231       // Sweep along the X axis first for RFFT. The order does not matter for
1232       // FFT and IFFT types; handle them here as well.
1233       for (int64_t sweep_axis = 0; sweep_axis < fft_rank_; sweep_axis++) {
1234         sweep(sweep_axis, fft_rank_ - 1, 0);
1235       }
1236     }
1237   }
1238 
1239   // This template generates two linearized indices, which can be used to access
1240   // multidimensional arrays. It uses a recursive function, which passes the
1241   // indices to the user-supplied callback function. The destination index is
1242   // always within dst_lengths[] bounds. The boolean parameter within_src_bounds
1243   // indicates whether the source index is within src_lengths[] bounds.
1244   //
1245   // The value returned from the callback function controls the recursion depth.
1246 // Returning true indicates that the base case has been hit and the recursion
1247   // stops. Otherwise, the recursion proceeds along the next less-major axis.
1248   //
1249   // For example, the base case when the axis value becomes negative invokes the
1250   // callback function for each possible index within dst_lengths[] bounds. The
1251   // base case when the axis value is equal to zero limits the indices to point
1252   // only to first elements along the minor-most dimension, allowing the
1253   // callback function to handle all values along the X axis.
1254   //
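  // Illustrative example (values chosen here): with dst_lengths = {4, 3},
  // dst_strides = {1, 4, 12}, rank = 2 and a callback that returns true once
  // axis == 0 (as in CopyDataFromInput below), the base case fires with
  // dst_index = 0, 4 and 8, i.e. once per line along the X axis.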
1255   template <typename BaseFn>
1256   static void GenerateIndices(const absl::Span<const int64> dst_lengths,
1257                               const absl::Span<const int64> dst_strides,
1258                               const absl::Span<const int64> src_lengths,
1259                               const absl::Span<const int64> src_strides,
1260                               int64_t rank, int64_t dst_start,
1261                               int64_t src_start, BaseFn&& base) {
1262     CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
1263     CHECK_GE(dst_lengths.size(), rank);
1264     CHECK_EQ(src_lengths.size() + 1, src_strides.size());
1265     CHECK_GE(src_lengths.size(), rank);
1266 
1267     std::function<void(int64_t, int64_t, int64_t, bool)> generate =
1268         [&](int64_t axis, int64_t dst_index, int64_t src_index,
1269             bool within_src_bounds) {
1270           if (!base(axis, dst_index, src_index, within_src_bounds)) {
1271             for (int64_t i = 0; i < dst_lengths[axis]; i++) {
1272               // Because the loop goes over dst_lengths[], the source index may
1273               // be out of src_lengths[] bounds. In this case, within_src_bounds
1274               // is false.
1275               within_src_bounds &= i < src_lengths[axis];
1276               generate(axis - 1, dst_index, src_index, within_src_bounds);
1277               dst_index += dst_strides[axis];
1278               src_index += src_strides[axis];
1279             }
1280           }
1281         };
1282     generate(rank - 1, dst_start, src_start, true);
1283   }
1284 
1285   // Copies the input data from a literal to a pre-allocated vector. The sizes
1286   // of the input and the transform do not need to match. For each axis of the
1287   // transform, any extra input values beyond the transform length are ignored.
1288   // Conversely, if the input does not contain enough elements along any axis,
1289   // the data is padded with zeroes.
1290   //
1291   // For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
1292   // where length_x is the size of the full transform along the X axis.
1293   //
1294   // The input literal may have a rank higher than the rank of the transform.
1295   // Passed-in input_start value points to the first element of the input
1296   // literal to be copied.
1297   //
1298   // Returns true if all values in the work data set are zeroes.
1299   //
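  // Illustrative example (values assumed here): for a rank-1 IRFFT with
  // fft_lengths = {8}, only the first 8 / 2 + 1 = 5 complex input values are
  // read; the omitted negative frequencies are re-created later, during the
  // final X-axis sweep (see Sweep above).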
1300   template <typename InputType>
1301   bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
1302                          int64_t fft_size,
1303                          const absl::Span<const int64> fft_lengths,
1304                          const absl::Span<const int64> fft_strides,
1305                          const absl::Span<const int64> input_lengths,
1306                          const absl::Span<const int64> input_strides,
1307                          absl::Span<ComplexType> data) {
1308     CHECK_GE(data.size(), fft_size);
1309 
1310     const bool input_is_truncated = fft_type_ == FftType::IRFFT;
1311 
1312     // Recursively visit each transform dimension to copy input values to the
1313     // working data set. The base case handles inputs along the X axis.
1314     bool input_is_zero = true;
1315     const InputType* input_data = input_literal.data<InputType>().data();
1316     auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
1317                          bool within_src_bounds) {
1318       if (axis == 0) {
1319         // For IRFFT, the negative frequencies are only needed for the sweep
1320         // along the X axis, which is performed last. Leave this part of the
1321         // working set uninitialized until then.
1322         const int64_t length = fft_lengths[axis];
1323         const int64_t ub = input_is_truncated ? (length / 2) + 1 : length;
1324         for (int64_t i = 0; i < ub; i++) {
1325           ComplexType value = ComplexType(0);
1326           // Read input value only if the index is within bounds.
1327           if (within_src_bounds && i < input_lengths[axis]) {
1328             value = TypeConverter<ComplexType, InputType>::GetAs(
1329                 input_data[src_index + i * input_strides[axis]]);
1330             input_is_zero &= value == ComplexType(0.0, 0.0);
1331           }
1332           data[dst_index + i * fft_strides[axis]] = value;
1333         }
1334         return true;
1335       }
1336       return false;
1337     };
1338     GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
1339                     fft_rank_, 0, input_start, base_case);
1340     return input_is_zero;
1341   }
1342 
1343   // Copies the result of the transform to the literal output. The sizes of the
1344   // transform and output must match.
1345   //
1346   // For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
1347   // the size of the full transform along the X axis (the most minor dimension).
1348   //
1349   // The output literal may have a rank higher than the rank of the transform.
1350   // Passed-in output_start value points to the first element of the output
1351   // literal to be filled in.
1352   //
1353   template <typename OutputType>
1354   void CopyDataToOutput(const absl::Span<ComplexType> data,
1355                         int64_t output_start,
1356                         const absl::Span<const int64> fft_lengths,
1357                         const absl::Span<const int64> fft_strides,
1358                         const absl::Span<const int64> output_lengths,
1359                         const absl::Span<const int64> output_strides,
1360                         Literal* output_literal) {
1361     const bool output_is_truncated = fft_type_ == FftType::RFFT;
1362 
1363     // Base case for recursive copy of the results to the output. The code
1364     // avoids making a recursive call for each output element by handling axis 0
1365   // in the loop (as opposed to making "axis < 0" the base case).
1366     OutputType* output_data = output_literal->data<OutputType>().data();
1367     auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
1368                          bool within_src_bounds) {
1369       if (axis == 0) {
1370         // Drop negative frequencies for RFFT.
1371         const int64_t length = fft_lengths[axis];
1372         const int64_t ub = output_is_truncated ? (length / 2) + 1 : length;
1373         for (int64_t i = 0; i < output_lengths[axis]; i++) {
1374           OutputType value = OutputType(0);
1375           // Read data only if the index is within bounds.
1376           if (within_src_bounds && i < ub) {
1377             value = TypeConverter<OutputType, ComplexType>::GetAs(
1378                 data[src_index + i * fft_strides[axis]]);
1379           }
1380           output_data[dst_index + i * output_strides[axis]] = value;
1381         }
1382         return true;
1383       }
1384       return false;
1385     };
1386     GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
1387                     fft_rank_, output_start, 0, base_case);
1388   }
1389 
1390   // Determine the type to use with the CopyDataFromInput<> template above.
1391   bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
1392                          int64_t fft_size,
1393                          const absl::Span<const int64> fft_lengths,
1394                          const absl::Span<const int64> fft_strides,
1395                          const absl::Span<const int64> input_lengths,
1396                          const absl::Span<const int64> input_strides,
1397                          absl::Span<ComplexType> data) {
1398     const bool input_is_float = fft_type_ == FftType::RFFT;
1399     if (input_is_float) {
1400       return CopyDataFromInput<float>(input_literal, input_start, fft_size,
1401                                       fft_lengths, fft_strides, input_lengths,
1402                                       input_strides, data);
1403     } else {
1404       return CopyDataFromInput<complex64>(input_literal, input_start, fft_size,
1405                                           fft_lengths, fft_strides,
1406                                           input_lengths, input_strides, data);
1407     }
1408   }
1409 
1410   // Determine the type to use with the CopyDataToOutput<> template above.
1411   void CopyDataToOutput(const absl::Span<ComplexType> data,
1412                         int64_t output_start,
1413                         const absl::Span<const int64> fft_lengths,
1414                         const absl::Span<const int64> fft_strides,
1415                         const absl::Span<const int64> output_lengths,
1416                         const absl::Span<const int64> output_strides,
1417                         Literal* output_literal) {
1418     const bool output_is_float = fft_type_ == FftType::IRFFT;
1419     if (output_is_float) {
1420       CopyDataToOutput<float>(data, output_start, fft_lengths, fft_strides,
1421                               output_lengths, output_strides, output_literal);
1422     } else {
1423       CopyDataToOutput<complex64>(data, output_start, fft_lengths, fft_strides,
1424                                   output_lengths, output_strides,
1425                                   output_literal);
1426     }
1427   }
1428 
1429   Status CheckParameters(const Shape& input_shape, const Shape& output_shape) {
1430     // Check FFT parameters.
1431     if (fft_rank_ <= 0) {
1432       return InvalidArgument("Zero or negative FFT rank.");
1433     }
1434     if (*absl::c_min_element(fft_lengths_) < 0) {
1435       return InvalidArgument("Negative FFT length.");
1436     }
1437 
1438     // Check input-related values.
1439     TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
1440     if (!input_shape.IsArray()) {
1441       return Unimplemented("Only array input shapes are supported.");
1442     }
1443     auto input_elt_type = input_shape.element_type();
1444     if (fft_type_ == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
1445       return InvalidArgument("Invalid input type: %d, must be %d (float).",
1446                              input_elt_type, PrimitiveType::F32);
1447     }
1448     if (fft_type_ != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
1449       return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
1450                              input_elt_type, PrimitiveType::C64);
1451     }
1452     const int64_t input_rank = input_shape.rank();
1453     if (input_rank < fft_rank_) {
1454       return InvalidArgument("Input shape rank is smaller than FFT rank.");
1455     }
1456 
1457     // Check output-related values.
1458     TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
1459     if (!output_shape.IsArray()) {
1460       return Unimplemented("Only array output shapes are supported.");
1461     }
1462     auto output_elt_type = output_shape.element_type();
1463     if (fft_type_ == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
1464       return InvalidArgument("Invalid output type: %d, must be %d (float).",
1465                              output_elt_type, PrimitiveType::F32);
1466     }
1467     if (fft_type_ != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
1468       return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
1469                              output_elt_type, PrimitiveType::C64);
1470     }
1471     const int64_t output_rank = output_shape.rank();
1472     if (output_rank < fft_rank_) {
1473       return InvalidArgument("Output shape rank is smaller than FFT rank.");
1474     }
1475 
1476     // Consistency of input and output parameters.
1477     if (input_rank != output_rank) {
1478       return InvalidArgument(
1479           "Ranks of input shape and output shape do not match.");
1480     }
1481     for (int64_t dim = 0; dim < input_rank - fft_rank_; dim++) {
1482       if (ShapeUtil::GetDimension(input_shape, dim) !=
1483           ShapeUtil::GetDimension(output_shape, dim)) {
1484         return InvalidArgument(
1485             "Higher dimension lengths of input shape and output shape do not "
1486             "match.");
1487       }
1488     }
1489 
1490     return Status::OK();
1491   }
1492 
1493  private:
1494   const FftType fft_type_;
1495   const int64 fft_rank_;
1496   std::vector<int64> fft_lengths_;
1497 };
1498 
1499 }  // namespace
1500 
1501 Status HloEvaluator::HandleFft(HloInstruction* fft) {
1502   const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
1503   Literal output_literal = Literal::CreateFromShape(fft->shape());
1504 
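  // Note: the transform is evaluated in complex128 regardless of the operand
  // element type (F32 or C64), presumably to limit rounding error; the result
  // is converted back to the output element type by CopyDataToOutput.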
1505   FftTransform<complex128> transform(fft);
1506   TF_RETURN_IF_ERROR(transform.ComputeFft(fft, input_literal, &output_literal));
1507   evaluated_[fft] = std::move(output_literal);
1508 
1509   return Status::OK();
1510 }
1511 
1512 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
1513 // dimensions while keeping the rest of the output dimensions clamped to 0.
1514 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
1515     const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
1516   int64_t output_rank = output_shape.dimensions_size();
1517   std::vector<int64> index_base(output_rank, 0);
1518   std::vector<int64> index_count;
1519   index_count.reserve(output_rank);
1520   for (int64_t i = 0; i < output_rank; i++) {
1521     bool is_output_batch_dim =
1522         !absl::c_binary_search(dim_numbers.offset_dims(), i);
1523     index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
1524   }
1525 
1526   return {std::move(index_base), std::move(index_count),
1527           std::vector<int64>(output_rank, 1)};
1528 }
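// Illustrative example (shapes assumed here): for an output of shape
// [5, 7, 10] with offset_dims = {1}, the returned space is base = {0, 0, 0},
// count = {5, 1, 10}, incr = {1, 1, 1}: dimension 1 stays clamped to 0 while
// the batch dimensions 0 and 2 are fully traversed.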
1529 
1530 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
1531 // dimensions while keeping the rest of the output dimensions clamped to 0.
1532 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
1533     int64_t output_rank, absl::Span<const int64> slice_sizes,
1534     const GatherDimensionNumbers& dim_numbers) {
1535   std::vector<int64> index_base(output_rank, 0);
1536   std::vector<int64> index_count(output_rank, 1);
1537   int64_t slice_sizes_idx = 0;
1538   for (int64_t i = 0; i < output_rank; i++) {
1539     bool is_output_window_dim =
1540         absl::c_binary_search(dim_numbers.offset_dims(), i);
1541     if (is_output_window_dim) {
1542       while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
1543                                    slice_sizes_idx)) {
1544         slice_sizes_idx++;
1545       }
1546       index_count[i] = slice_sizes[slice_sizes_idx++];
1547     }
1548   }
1549 
1550   return {std::move(index_base), std::move(index_count),
1551           std::vector<int64>(output_rank, 1)};
1552 }
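// Illustrative example (values assumed here): with output_rank = 3,
// offset_dims = {1}, slice_sizes = {1, 4} and collapsed_slice_dims = {0}, the
// collapsed slice size is skipped and the returned space has
// count = {1, 4, 1}.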
1553 
1554 // This functor computes the contribution of start_indices to an input index
1555 // corresponding to an output index.  That is, given an output index I, it picks
1556 // out the batch indices in I and uses them to look up a starting index, G, from
1557 // the start indices tensor, and expands G into the input space according to
1558 // start_index_map.
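// Illustrative example (dimension numbers assumed here): for a rank-3 input
// with start_index_map = {0, 2}, a fetched index vector G = {g0, g1}
// contributes {g0, 0, g1} to the input index; dimension 1 receives no
// contribution from the start indices.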
1559 class OutputBatchIndexToInputIndex {
1560  public:
1561   // The constructor does some setup work that is amortized across all
1562   // iterations.
1563   explicit OutputBatchIndexToInputIndex(
1564       const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
1565       const Shape& output_shape, const Literal* start_indices)
1566       : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
1567     for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
1568       output_dim_is_batch_dims_.push_back(
1569           !absl::c_binary_search(dim_numbers_.offset_dims(), i));
1570     }
1571 
1572     for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
1573       int64_t index_of_input_dim_in_index_vector =
1574           std::distance(dim_numbers_.start_index_map().begin(),
1575                         absl::c_find(dim_numbers_.start_index_map(), i));
1576       if (index_of_input_dim_in_index_vector ==
1577           dim_numbers_.start_index_map_size()) {
1578         input_dim_value_to_index_vector_.push_back(-1);
1579       } else {
1580         input_dim_value_to_index_vector_.push_back(
1581             index_of_input_dim_in_index_vector);
1582       }
1583     }
1584 
1585     index_vector_index_.resize(start_indices_.shape().dimensions_size());
1586     input_index_.resize(input_shape.dimensions_size());
1587     int64_t index_vector_size =
1588         start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
1589     index_vector_.resize(index_vector_size);
1590   }
1591 
1592   // Returns the contribution of start_indices to the input index corresponding
1593   // to output_index.  See gather_inner_loop_body.
1594   //
1595   // This is conceptually  a stateless transformation from output_index to the
1596   // gather input index, but:
1597   //
1598   //  - Instead of allocating memory to represent the gather input index on
1599   //    every invocation we reuse the same storage for the result
1600   //    (input_index_), mutating it in place.
1601   //  - Instead of allocating buffers for temporary values like
1602   //    index_vector_index_ and index_vector on every invocation, we reuse the
1603   //    same storage for all invocations.
1604   //
1605   // This returns a Span into memory owned by the class.
1606   StatusOr<absl::Span<const int64>> operator()(
1607       absl::Span<const int64> output_index) {
1608     PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
1609     TF_RETURN_IF_ERROR(FetchIndexVector());
1610     PropagateIndexVectorToInputIndex();
1611     return absl::Span<const int64>(input_index_);
1612   }
1613 
1614  private:
1615   // Propagates the batch dimensions from the output index into
1616   // index_vector_index_ by mutating index_vector_index_ in place.  Does not
1617   // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
1618   // we iterate over in FetchIndexVector.
1619   void PropagateOutputIndexGatherDimsToIndexVectorIndex(
1620       absl::Span<const int64> output_index) {
1621     int64_t index_vector_index_i = 0;
1622     for (int64_t i = 0, e = output_index.size(); i < e; i++) {
1623       if (!output_dim_is_batch_dims_[i]) {
1624         continue;
1625       }
1626 
1627       if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
1628         index_vector_index_i++;
1629       }
1630 
1631       index_vector_index_[index_vector_index_i++] = output_index[i];
1632     }
1633   }
1634 
1635   // Populates index_vector_ by iterating over start_indices_ according to
1636   // index_vector_index_.
1637   Status FetchIndexVector() {
1638     int64_t index_vector_dim = dim_numbers_.index_vector_dim();
1639     for (int64_t i = 0, e = index_vector_.size(); i < e; i++) {
1640       index_vector_index_[index_vector_dim] = i;
1641       auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
1642       TF_RET_CHECK(start_index.has_value());
1643       index_vector_[i] = *start_index;
1644     }
1645     return Status::OK();
1646   }
1647 
1648   // Populates input_index_.
1649   void PropagateIndexVectorToInputIndex() {
1650     for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
1651       if (input_dim_value_to_index_vector_[i] != -1) {
1652         input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
1653       }
1654 
1655       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
1656       // remains 0, as set by the constructor.
1657     }
1658   }
1659 
1660   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
1661   // the input index from the index vector.  See
1662   // PropagateIndexVectorToInputIndex.
1663   std::vector<int64> input_dim_value_to_index_vector_;
1664 
1665   // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
1666   // dimension.
1667   std::vector<bool> output_dim_is_batch_dims_;
1668 
1669   // The buffer into which we construct an index into start_indices_ to fetch
1670   // the index vector.
1671   std::vector<int64> index_vector_index_;
1672 
1673   // The index vector fetched from start_indices_.
1674   std::vector<int64> index_vector_;
1675 
1676   // The result computed by this functor.  operator() returns a Span into
1677   // this vector.
1678   std::vector<int64> input_index_;
1679 
1680   const GatherDimensionNumbers& dim_numbers_;
1681   const Literal& start_indices_;
1682 };
1683 
1684 // This functor computes the contribution of the offset indices in an output
1685 // index to an input index.  That is, given an output index I it picks out the
1686 // output offset indices in I and expands them into an index into the input shape.
1687 class OutputOffsetIndexToInputIndex {
1688  public:
1689   // The constructor does some setup work that is amortized across all
1690   // iterations.
1691   explicit OutputOffsetIndexToInputIndex(
1692       const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
1693       const Shape& output_shape) {
1694     std::vector<int64> window_index_to_output_index;
1695     int64_t output_index_count = 0;
1696     for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
1697       if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
1698         window_index_to_output_index.push_back(output_index_count++);
1699       } else {
1700         output_index_count++;
1701       }
1702     }
1703 
1704     int64_t window_dim_count = 0;
1705     for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
1706       if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
1707         input_dim_value_to_output_index_.push_back(-1);
1708       } else {
1709         input_dim_value_to_output_index_.push_back(
1710             window_index_to_output_index[window_dim_count++]);
1711       }
1712     }
1713 
1714     input_index_.resize(input_shape.dimensions_size());
1715   }
1716 
1717   // Returns the contribution of the window indices to the input index
1718   // corresponding to output_index.  See gather_inner_loop_body.
1719   //
1720   // This is conceptually a stateless transformation from output_index to the
1721   // window input index, but instead of allocating memory to represent the
1722   // gather input index on every invocation we reuse the same storage for the
1723   // result (input_index_), mutating it in place.
1724   //
1725   // This returns a Span into memory owned by the class.
1726   StatusOr<absl::Span<const int64>> operator()(
1727       absl::Span<const int64> output_index) {
1728     PropagateOutputIndexWindowDimsToInputIndex(output_index);
1729     return absl::Span<const int64>(input_index_);
1730   }
1731 
1732   // Returns for a given 'input_dim' the corresponding output dimension index,
1733   // or -1 if 'input_dim' is an elided window dimension.
1734   int64 input_dim_value_to_output_index(int64_t input_dim) {
1735     return input_dim_value_to_output_index_[input_dim];
1736   }
1737 
1738  private:
1739   // Propagates window dimensions from the output index to input_index_ by
1740   // mutating input_index_ in place.
1741   void PropagateOutputIndexWindowDimsToInputIndex(
1742       absl::Span<const int64> output_index) {
1743     for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
1744       if (input_dim_value_to_output_index_[i] != -1) {
1745         input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
1746       }
1747 
1748       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
1749       // remains 0, as set by the constructor.
1750     }
1751   }
1752 
1753   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
1754   // the input index from the output index. See
1755   // PropagateOutputIndexWindowDimsToInputIndex.
1756   std::vector<int64> input_dim_value_to_output_index_;
1757 
1758   // The result computed by this functor.  operator() returns a Span into
1759   // this vector.
1760   std::vector<int64> input_index_;
1761 };
1762 
1763 // Reshapes the gather indices input to have a trailing degenerate `1` dimension
1764 // if necessary.  Hands over the ownership of the newly created literal (if
1765 // there is one) to `reshaped_start_indices`.
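// For example (illustrative): if start_indices has shape s64[10] and
// index_vector_dim == 1, the indices are reshaped to s64[10, 1]; otherwise the
// original literal is returned unchanged.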
1766 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
1767     int64_t index_vector_dim, const Literal& start_indices,
1768     Literal* reshaped_start_indices) {
1769   if (start_indices.shape().dimensions_size() != index_vector_dim) {
1770     return std::cref(start_indices);
1771   }
1772 
1773   std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
1774                                start_indices.shape().dimensions().end());
1775   new_shape.push_back(1);
1776   TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
1777                       start_indices.Reshape(new_shape));
1778   return std::cref(*reshaped_start_indices);
1779 }
1780 
1781 Status HloEvaluator::HandleGather(HloInstruction* gather) {
1782   Literal result = Literal::CreateFromShape(gather->shape());
1783   const Shape& shape = gather->shape();
1784   const GatherDimensionNumbers& dim_numbers =
1785       gather->gather_dimension_numbers();
1786   const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
1787   Literal reshaped_start_indices;
1788   TF_ASSIGN_OR_RETURN(
1789       const Literal& start_indices,
1790       ReshapedGatherIndices(dim_numbers.index_vector_dim(),
1791                             GetEvaluatedLiteralFor(gather->operand(1)),
1792                             &reshaped_start_indices));
1793 
1794   // We iterate over the gather dimensions in the output shape in an outer loop
1795   // nest, and iterate over the window dimensions in the output shape in an
1796   // inner loop nest.
1797 
1798   ShapeUtil::IndexIterationSpace start_indices_iteration_space =
1799       IterationSpaceForOutputBatchIndices(shape, dim_numbers);
1800   ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
1801       IterationSpaceForOutputOffsetIndices(
1802           shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
1803 
1804   // Scratch buffers that hold an index in the output shape and the
1805   // corresponding index in the input shape.
1806   std::vector<int64> input_index(operand.shape().dimensions_size());
1807   std::vector<int64> output_index(gather->shape().dimensions_size());
1808   std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
1809 
1810   OutputBatchIndexToInputIndex output_batch_index_to_input_index(
1811       &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1812       /*output_shape=*/shape, &start_indices);
1813   OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
1814       gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1815       /*output_shape=*/shape);
1816 
1817   const Shape& operand_shape = operand.shape();
1818   if (ShapeUtil::IsZeroElementArray(operand_shape)) {
1819     evaluated_[gather] = std::move(result);
1820     return Status::OK();
1821   }
1822 
1823   auto gather_inner_loop_body =
1824       [&](absl::Span<const int64> output_window_index,
1825           absl::Span<const int64> input_gather_index,
1826           absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1827     TF_ASSIGN_OR_RETURN(
1828         absl::Span<const int64> input_window_index,
1829         output_offset_index_to_input_index(output_window_index));
1830     for (int i = 0, e = output_index.size(); i < e; i++) {
1831       output_index[i] = output_gather_index[i] + output_window_index[i];
1832       DCHECK_LT(output_index[i], shape.dimensions(i));
1833     }
1834     for (int i = 0, e = input_gather_index.size(); i < e; i++) {
1835       int64_t output_dim =
1836           output_offset_index_to_input_index.input_dim_value_to_output_index(i);
1837       // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
1838       // we set the iteration index to 0, so for the purpose of the following
1839       // calculations we can consider the output dimension size to be 1.
1840       int64_t output_dim_size =
1841           output_dim == -1 ? 1 : shape.dimensions(output_dim);
1842       // Clamp the gather index so that the gather region fits in the operand.
1843       // input_index_clamped[i] = clamp(input_gather_index[i], 0,
1844       //                                       operand_shape.dimensions(i) -
1845       //                                       output_dim_size);
1846       input_index_clamped[i] =
1847           std::min(operand_shape.dimensions(i) - output_dim_size,
1848                    std::max(int64{0}, input_gather_index[i]));
1849     }
1850     for (int i = 0, e = input_index.size(); i < e; i++) {
1851       input_index[i] = input_index_clamped[i] + input_window_index[i];
1852       DCHECK_GE(input_index[i], 0);
1853       DCHECK_LT(input_index[i], operand_shape.dimensions(i));
1854     }
1855     TF_RETURN_IF_ERROR(
1856         result.CopyElementFrom(operand, input_index, output_index));
1857     return true;
1858   };
1859 
1860   auto gather_outer_loop_body =
1861       [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1862     TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
1863                         output_batch_index_to_input_index(output_gather_index));
1864     TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1865         shape, offset_indices_iteration_space,
1866         std::bind(gather_inner_loop_body, std::placeholders::_1,
1867                   input_gather_index, output_gather_index)));
1868     return true;
1869   };
1870 
1871   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1872       shape, start_indices_iteration_space, gather_outer_loop_body));
1873   evaluated_[gather] = std::move(result);
1874   return Status::OK();
1875 }
1876 
1877 Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
1878   const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
1879 
1880   TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
1881       << "broadcast dimensions is of size: " << broadcast->dimensions().size()
1882       << " and rank of operand_to_broadcast is: " << operand.shape().rank();
1883   // Checks that operand's dimensions are the same as the broadcast's
1884   // dimensions along the dimensions to be broadcasted.
1885   for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
1886     auto operand_dim_size = operand.shape().dimensions(i);
1887     auto broadcast_dim_size =
1888         broadcast->shape().dimensions(broadcast->dimensions(i));
1889     TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
1890         "Operand dimension %d is broadcast to output dimension %d, but the "
1891         "sizes of these two dims do not match (%d vs %d): %s",
1892         i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
1893         broadcast->ToString());
1894   }
1895 
1896   TF_ASSIGN_OR_RETURN(
1897       evaluated_[broadcast],
1898       operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
1899 
1900   return Status::OK();
1901 }
1902 
1903 Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
1904   evaluated_[after_all] = LiteralUtil::CreateToken();
1905   return Status::OK();
1906 }
1907 
1908 Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
1909   // AddDependency just forwards its zeroth operand.
1910   evaluated_[add_dependency] =
1911       GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
1912   return Status::OK();
1913 }
1914 
1915 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
1916   const auto result_shape = get_tuple_element->shape();
1917   const int64_t index = get_tuple_element->tuple_index();
1918 
1919   auto operand = get_tuple_element->operand(0);
1920   TF_ASSIGN_OR_RETURN(
1921       auto inferred_return_shape,
1922       ShapeInference::InferGetTupleElementShape(operand->shape(), index));
1923   TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
1924       << "return shape set to: " << ShapeUtil::HumanString(result_shape)
1925       << " but is inferred to be: "
1926       << ShapeUtil::HumanString(inferred_return_shape);
1927 
1928   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1929 
1930   evaluated_[get_tuple_element] =
1931       Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
1932   return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
1933                                                 /*dest_shape_index=*/{},
1934                                                 /*src_shape_index=*/{index});
1935 }
1936 
1937 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
1938   TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
1939   evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
1940   return Status::OK();
1941 }
1942 
1943 Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
1944   if (copy_start->user_count() != 1 ||
1945       copy_start->users().at(0)->opcode() != HloOpcode::kCopyDone) {
1946     return tensorflow::errors::FailedPrecondition(
1947         "Cannot evaluate a kCopyStart that doesn't have a single kCopyDone "
1948         "user.");
1949   }
1950 
1951   // The context in index {2} is undefined, but since we can't represent
1952   // undefined values using a Literal, we just use 0. This should be safe though
1953   // since we ensure that the only user of a kCopyStart is a kCopyDone which
1954   // consumes the context. Also note that MakeTuple copies its arguments, so
1955   // this is memory-safe.
1956   const Literal context_literal = LiteralUtil::CreateR0<uint32>(0);
1957   evaluated_[copy_start] = LiteralUtil::MakeTuple(
1958       {&GetEvaluatedLiteralFor(copy_start->operand(0)),
1959        &GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
1960   return Status::OK();
1961 }
1962 
1963 Status HloEvaluator::HandleCopyDone(HloInstruction* copy_done) {
1964   const HloInstruction* operand = copy_done->operand(0);
1965   if (operand->opcode() != HloOpcode::kCopyStart) {
1966     return tensorflow::errors::FailedPrecondition(
1967         "Cannot evaluate a kCopyDone that doesn't have a kCopyStart as "
1968         "operand.");
1969   }
1970 
1971   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1972 
1973   evaluated_[copy_done] =
1974       Literal(ShapeUtil::GetTupleElementShape(operand->shape(), /*index=*/0));
1975   TF_RETURN_IF_ERROR(evaluated_[copy_done].CopyFrom(operand_tuple_literal,
1976                                                     /*dest_shape_index=*/{},
1977                                                     /*src_shape_index=*/{0}));
1978   return Status::OK();
1979 }
1980 
1981 Status HloEvaluator::HandleCall(HloInstruction* call) {
1982   auto* computation = call->to_apply();
1983   auto operands = call->operands();
1984 
1985   std::vector<const Literal*> arg_literals;
1986   arg_literals.reserve(operands.size());
1987   for (auto operand : operands) {
1988     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1989     arg_literals.push_back(&arg_literal);
1990   }
1991 
1992   HloEvaluator embedded_evaluator;
1993   embedded_evaluator.set_dynamic_dimension_inference(
1994       dynamic_dimension_inference_);
1995   TF_ASSIGN_OR_RETURN(Literal result,
1996                       embedded_evaluator.Evaluate(*computation, arg_literals));
1997 
1998   evaluated_[call] = std::move(result);
1999   return Status::OK();
2000 }
2001 
2002 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
2003   HloModuleConfig config;
2004   // Attach the cloned computation to an empty HLO module so that the original
2005   // module and computation are not modified.
2006   HloModule empty_hlo_module("EmptyModuleForFusion", config);
2007   HloCloneContext context(&empty_hlo_module);
2008   auto cloned_fused_computation =
2009       fusion->fused_instructions_computation()->Clone(
2010           /*suffix=*/"clone_with_layout", &context);
2011   for (auto* instruction : cloned_fused_computation->instructions()) {
2012     if (!LayoutUtil::HasLayout(instruction->shape())) {
2013       LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
2014     }
2015   }
2016   auto readded_computation =
2017       empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
2018 
2019   auto operands = fusion->operands();
2020   std::vector<const Literal*> arg_literals;
2021   arg_literals.reserve(operands.size());
2022   for (auto operand : operands) {
2023     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
2024     arg_literals.push_back(&arg_literal);
2025   }
2026 
2027   HloEvaluator embedded_evaluator;
2028   embedded_evaluator.set_dynamic_dimension_inference(
2029       dynamic_dimension_inference_);
2030   TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
2031                                           *readded_computation, arg_literals));
2032 
2033   evaluated_[fusion] = std::move(result);
2034   return Status::OK();
2035 }
2036 
2037 Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
2038   const auto& branch_index_literal =
2039       GetEvaluatedLiteralFor(conditional->operand(0));
2040   int branch_index;
2041   if (conditional->operand(0)->shape().element_type() == PRED) {
2042     branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
2043   } else {
2044     branch_index = branch_index_literal.Get<int32>({});
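    // Out-of-range branch indices select the last branch, which HLO treats as
    // the default branch for indexed conditionals.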
2045     if (branch_index < 0 || branch_index >= conditional->branch_count()) {
2046       branch_index = conditional->branch_count() - 1;
2047     }
2048   }
2049   const auto& branch_computation_arg =
2050       GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
2051 
2052   HloEvaluator embedded_evaluator;
2053   embedded_evaluator.set_dynamic_dimension_inference(
2054       dynamic_dimension_inference_);
2055   TF_ASSIGN_OR_RETURN(Literal result,
2056                       embedded_evaluator.Evaluate(
2057                           *conditional->branch_computation(branch_index),
2058                           {&branch_computation_arg}));
2059 
2060   evaluated_[conditional] = std::move(result);
2061   return Status::OK();
2062 }
2063 
2064 Status HloEvaluator::HandleSelect(HloInstruction* select) {
2065   const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
2066   const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
2067   const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
2068 
2069   // If the predicate is of scalar type, no element-wise selection is needed.
2070   if (ShapeUtil::IsScalar(pred.shape())) {
2071     if (pred.Get<bool>({})) {
2072       evaluated_[select] = on_true.Clone();
2073     } else {
2074       evaluated_[select] = on_false.Clone();
2075     }
2076     return Status::OK();
2077   }
2078 
2079   return DefaultAction(select);
2080 }
2081 
2082 Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
2083   const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
2084   const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
2085   const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
2086 
2087   if (pred.Get<bool>({})) {
2088     evaluated_[tuple_select] = on_true.Clone();
2089   } else {
2090     evaluated_[tuple_select] = on_false.Clone();
2091   }
2092   return Status::OK();
2093 }
2094 
2095 Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
2096   HloComputation* cond_comp = while_hlo->while_condition();
2097   HloComputation* body_comp = while_hlo->while_body();
2098   // Initialize the loop-carried value with the input to the While instruction.
2099   auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
2100   bool keep_going = true;
2101   int64_t iteration_count = 0;
2102   HloEvaluator cond_evaluator(max_loop_iterations_);
2103   cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
2104   HloEvaluator loop_body_evaluator(max_loop_iterations_);
2105   loop_body_evaluator.set_dynamic_dimension_inference(
2106       dynamic_dimension_inference_);
2107   while (keep_going) {
2108     if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
2109       return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
2110                              while_hlo->name(), max_loop_iterations_);
2111     }
2112     TF_ASSIGN_OR_RETURN(auto cond_val,
2113                         cond_evaluator.Evaluate(*cond_comp, {&lcv}));
2114     keep_going = cond_val.GetFirstElement<bool>();
2115     if (keep_going) {
2116       TF_ASSIGN_OR_RETURN(auto body_val,
2117                           loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
2118       VLOG(3) << "Loop iteration result: " << body_val.ToString();
2119       lcv = std::move(body_val);
2120       cond_evaluator.ResetVisitStates();
2121       loop_body_evaluator.ResetVisitStates();
2122     }
2123   }
2124   evaluated_[while_hlo] = std::move(lcv);
2125   return Status::OK();
2126 }
2127 
2128 namespace {
2129 template <typename NativeT>
2130 Literal ExtractLiteralFromIndexPositions(const Literal& from,
2131                                          absl::Span<int64 const> indices,
2132                                          bool extract_as_scalar) {
2133   if (extract_as_scalar) {
2134     return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
2135   }
2136   // We use an InlinedVector here because we need to convert it to an
2137   // absl::Span later, and this would not work with std::vector<bool>.
2138   absl::InlinedVector<NativeT, 10> values;
2139   for (int64_t index : indices) {
2140     values.push_back(from.Get<NativeT>({index}));
2141   }
2142   return LiteralUtil::CreateR1<NativeT>(values);
2143 }
2144 
2145 StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
2146                                             absl::Span<int64 const> indices,
2147                                             bool extract_as_scalar = false) {
2148   if (extract_as_scalar) {
2149     CHECK_EQ(indices.size(), 1);
2150   }
2151   PrimitiveType type = from.shape().element_type();
2152   switch (type) {
2153     case PRED: {
2154       return ExtractLiteralFromIndexPositions<bool>(from, indices,
2155                                                     extract_as_scalar);
2156     }
2157     case U8: {
2158       return ExtractLiteralFromIndexPositions<uint8>(from, indices,
2159                                                      extract_as_scalar);
2160     }
2161     case S8: {
2162       return ExtractLiteralFromIndexPositions<int8>(from, indices,
2163                                                     extract_as_scalar);
2164     }
2165     case BF16: {
2166       return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
2167                                                         extract_as_scalar);
2168     }
2169     case F16: {
2170       return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
2171                                                            extract_as_scalar);
2172     }
2173     case U16: {
2174       return ExtractLiteralFromIndexPositions<uint16>(from, indices,
2175                                                       extract_as_scalar);
2176     }
2177     case S16: {
2178       return ExtractLiteralFromIndexPositions<int16>(from, indices,
2179                                                      extract_as_scalar);
2180     }
2181     case F32: {
2182       return ExtractLiteralFromIndexPositions<float>(from, indices,
2183                                                      extract_as_scalar);
2184     }
2185     case U32: {
2186       return ExtractLiteralFromIndexPositions<uint32>(from, indices,
2187                                                       extract_as_scalar);
2188     }
2189     case S32: {
2190       return ExtractLiteralFromIndexPositions<int32>(from, indices,
2191                                                      extract_as_scalar);
2192     }
2193     case F64: {
2194       return ExtractLiteralFromIndexPositions<double>(from, indices,
2195                                                       extract_as_scalar);
2196     }
2197     case C64: {
2198       return ExtractLiteralFromIndexPositions<std::complex<float>>(
2199           from, indices, extract_as_scalar);
2200     }
2201     case U64: {
2202       return ExtractLiteralFromIndexPositions<uint64>(from, indices,
2203                                                       extract_as_scalar);
2204     }
2205     case S64: {
2206       return ExtractLiteralFromIndexPositions<int64>(from, indices,
2207                                                      extract_as_scalar);
2208     }
2209     case C128: {
2210       return ExtractLiteralFromIndexPositions<std::complex<double>>(
2211           from, indices, extract_as_scalar);
2212     }
2213     default:
2214       return InvalidArgument("Unsupported type for Sort: %s",
2215                              PrimitiveType_Name(type));
2216   }
2217 }
2218 }  // namespace
2219 
2220 Status HloEvaluator::HandleSort(HloInstruction* sort) {
2221   TF_RET_CHECK(sort->operand_count() >= 1)
2222       << "Expected at least 1 operand for sort";
2223   for (int64_t i = 1; i < sort->operand_count(); ++i) {
2224     TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
2225                                            sort->operand(i)->shape()))
2226         << "All Sort operands must have the same dimensions";
2227   }
2228 
2229   if (VLOG_IS_ON(3)) {
2230     for (int64_t i = 0; i < sort->operand_count(); ++i) {
2231       VLOG(3) << "HandleSort operand " << i << " literal: "
2232               << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
2233     }
2234   }
2235   Shape key_shape = sort->operand(0)->shape();
2236   auto rank = key_shape.rank();
2237   std::vector<Literal> result_literals;
2238   result_literals.reserve(sort->operand_count());
2239   for (int64_t i = 0; i < sort->operand_count(); ++i) {
2240     result_literals.emplace_back(sort->operand(i)->shape());
2241   }
2242   std::vector<int64> zero_base(rank, 0);
2243   std::vector<int64> increment(rank, 1);
2244   int64_t sort_dim = sort->dimensions(0);
2245   int64_t sort_dim_elements = key_shape.dimensions(sort_dim);
2246   increment[sort_dim] = sort_dim_elements;
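  // Stepping by the full extent of 'sort_dim' means the iteration below visits
  // only coordinate 0 along that dimension, i.e. one visit per 1-D slice to be
  // sorted.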
2247   HloEvaluator embedded_evaluator(max_loop_iterations_);
2248   // Iterate through each dimension except 'sort_dim'.
2249   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2250       key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
2251       [&](absl::Span<const int64> indices) -> StatusOr<bool> {
2252         // Extract a slice from each operand literal that corresponds to
2253         // exactly the row in dimension 'sort_dim'.
2254         std::vector<int64> limit_indices(indices.begin(), indices.end());
2255         absl::c_for_each(limit_indices, [](int64& index) { ++index; });
2256         limit_indices[sort_dim] = sort_dim_elements;
2257         std::vector<Literal> literals_to_sort;
2258         literals_to_sort.reserve(sort->operand_count());
2259         for (int64_t i = 0; i < sort->operand_count(); ++i) {
2260           TF_ASSIGN_OR_RETURN(auto literal_to_sort,
2261                               GetEvaluatedLiteralFor(sort->operand(i))
2262                                   .Slice(indices, limit_indices)
2263                                   .Reshape({sort_dim_elements}));
2264           literals_to_sort.push_back(std::move(literal_to_sort));
2265         }
2266         std::vector<int64> indices_to_sort(sort_dim_elements);
2267         std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
2268         Status compare_status = Status::OK();
2269         auto comparator = [sort, &compare_status, &embedded_evaluator,
2270                            &literals_to_sort](int64_t a, int64_t b) {
2271           std::vector<Literal> literals;
2272           literals.reserve(2 * sort->operand_count());
2273           for (int64_t i = 0; i < sort->operand_count(); ++i) {
2274             auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
2275                                                  /*extract_as_scalar=*/true);
2276             if (!lhs.ok()) {
2277               compare_status = lhs.status();
2278               return false;
2279             }
2280             literals.push_back(std::move(lhs.ValueOrDie()));
2281             auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
2282                                                  /*extract_as_scalar=*/true);
2283             if (!rhs.ok()) {
2284               compare_status = rhs.status();
2285               return false;
2286             }
2287             literals.push_back(std::move(rhs.ValueOrDie()));
2288           }
2289           std::vector<const Literal*> literal_ptrs;
2290           absl::c_transform(literals, std::back_inserter(literal_ptrs),
2291                             [](const Literal& literal) { return &literal; });
2292 
2293           auto computed_result =
2294               embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
2295           // Clear visit states so that we can use the evaluator again
2296           // on the same computation.
2297           embedded_evaluator.ResetVisitStates();
2298           if (!computed_result.ok()) {
2299             compare_status = computed_result.status();
2300             return false;
2301           }
2302           return computed_result.ValueOrDie().Get<bool>({});
2303         };
2304         if (Cast<HloSortInstruction>(sort)->is_stable()) {
2305           std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
2306                            comparator);
2307         } else {
2308           std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
2309         }
2310         if (!compare_status.ok()) {
2311           return compare_status;
2312         }
        std::vector<int64> slice_dimensions(rank, 1);
        slice_dimensions[sort_dim] = sort_dim_elements;
        std::vector<int64> start_indices(rank, 0);
        for (int64_t i = 0; i < sort->operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(
              Literal sorted_literal,
              ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
          TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
                              sorted_literal.Reshape(slice_dimensions));
          TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
              sorted_literal_reshaped, start_indices, indices,
              slice_dimensions));
        }
        return true;
      }));

  if (sort->operand_count() == 1) {
    evaluated_[sort] = std::move(result_literals[0]);
  } else {
    std::vector<const Literal*> literal_ptrs;
    absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
                      [](const Literal& literal) { return &literal; });

    Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
    VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();

    evaluated_[sort] = std::move(result_tuple);
  }
  return Status::OK();
}
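
// Returns true if `computation` is a two-parameter scalar add: its root is a
// kAdd whose operands are the two distinct scalar parameters. Used below to
// enable a fast path for plain floating-point add-reductions.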
static bool IsScalarAdd(HloComputation* computation) {
  HloInstruction* instruction = computation->root_instruction();
  if (instruction->opcode() == HloOpcode::kAdd &&
      computation->num_parameters() == 2) {
    const HloInstruction* lhs = instruction->operand(0);
    const HloInstruction* rhs = instruction->operand(1);
    return lhs->opcode() == HloOpcode::kParameter &&
           ShapeUtil::IsScalar(lhs->shape()) &&
           rhs->opcode() == HloOpcode::kParameter &&
           ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
  }
  return false;
}

// Run a single step of an inner loop while running reduction, which applies
// the user-provided computation to the accumulator and the current input
// element (until the reduction is completed, the output element doubles as
// the accumulator).
static StatusOr<bool> PerformReductionStep(
    bool is_tuple, absl::Span<const int64> input_index,
    absl::Span<const int64> output_index,
    absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
    HloComputation* computation, HloEvaluator* embedded_evaluator) {
  int num_args = results.size();

  absl::InlinedVector<Literal, 1> arg_values;
  arg_values.reserve(num_args);
  absl::InlinedVector<Literal, 1> accumulators;
  accumulators.reserve(num_args);
  for (int64_t i = 0; i < num_args; ++i) {
    arg_values.emplace_back(
        ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
    accumulators.emplace_back(
        ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));

    TF_RETURN_IF_ERROR(
        arg_values[i].CopyElementFrom(*input_args[i], input_index, {}));
    TF_RETURN_IF_ERROR(
        accumulators[i].CopyElementFrom(results[i], output_index, {}));
  }

  // Evaluate computation with specified literal operands.
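  // The computation is invoked as f(acc_0, ..., acc_{n-1}, x_0, ..., x_{n-1}):
  // all accumulators first, then the current input elements, matching the
  // order in which the operands are packed below.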
  absl::InlinedVector<Literal*, 2> embedded_operands;
  for (Literal& accumulator : accumulators) {
    embedded_operands.push_back(&accumulator);
  }
  for (Literal& local_input : arg_values) {
    embedded_operands.push_back(&local_input);
  }

  TF_ASSIGN_OR_RETURN(
      Literal computed_result,
      embedded_evaluator->Evaluate(*computation, embedded_operands));

  // Clear visit states so that we can use the evaluator again on the same
  // computation.
  embedded_evaluator->ResetVisitStates();

  if (is_tuple) {
    std::vector<Literal> computed_results = computed_result.DecomposeTuple();
    for (int64_t i = 0; i < num_args; ++i) {
      TF_RETURN_IF_ERROR(
          results[i].CopyElementFrom(computed_results[i], {}, output_index));
    }
  } else {
    TF_RETURN_IF_ERROR(
        results[0].CopyElementFrom(computed_result, {}, output_index));
  }

  return true;
}

static StatusOr<bool> GenerateReduceOutputElement(
    bool is_tuple, absl::Span<const int64> output_index,

    absl::Span<const Literal* const> init_values,
    absl::Span<const Literal* const> input_args, absl::Span<Literal> results,

    HloComputation* function, HloEvaluator* embedded_evaluator,

    absl::Span<const int64> arg_dim_steps,
    absl::Span<const int64> arg_dim_counts,
    absl::Span<const int64> result_to_arg_index) {
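  // Fast path: when the reduction is a plain scalar add over floating-point
  // data (and not a variadic/tuple reduce), accumulate directly in a double
  // instead of invoking the embedded evaluator for every element.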
  bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) &&
                      IsScalarAdd(function) && !is_tuple;

  const Shape& arg_shape = input_args[0]->shape();
  absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
  std::vector<int64> base(arg_dimensions.size());
  for (int64_t i = 0; i < output_index.size(); ++i) {
    base[result_to_arg_index[i]] = output_index[i];
  }

  for (int64_t i = 0; i < results.size(); ++i) {
    TF_RETURN_IF_ERROR(
        results[i].CopyElementFrom(*init_values[i], {}, output_index));
  }

  if (use_fast_add) {
    double computed_result = *init_values[0]->GetAsDouble({});
    auto reduction_step =
        [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
      double argument = *input_args[0]->GetAsDouble(input_index);
      computed_result += argument;
      return true;
    };
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step));
    TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result));
    return true;
  }

  // Iterates only over reduced shape, as counts and steps are set to zero
  // for all non-reduced dimensions.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      arg_shape, base, arg_dim_counts, arg_dim_steps,
      [&](absl::Span<const int64> input_index) {
        return PerformReductionStep(is_tuple, input_index, output_index,
                                    input_args, results, function,
                                    embedded_evaluator);
      }));
  return true;
}

Status HloEvaluator::HandleReduce(HloInstruction* instr) {
  HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
  int64_t num_args = reduce->inputs().size();
  absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
  HloComputation* function = reduce->to_apply();

  absl::InlinedVector<const Shape*, 1> operand_shapes;
  for (const HloInstruction* operand : reduce->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                      ShapeInference::InferReduceShape(
                          operand_shapes, dimensions_to_reduce,
                          /*to_apply=*/function->ComputeProgramShape()));
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(),
                                                        inferred_return_shape))
      << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
      << " but is inferred to be: "
      << ShapeUtil::HumanString(inferred_return_shape);

  absl::InlinedVector<const Literal*, 1> input_args(num_args);
  absl::InlinedVector<const Literal*, 1> init_values(num_args);
  for (int64_t i = 0; i < num_args; ++i) {
    input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]);
    VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString();
    init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]);
    VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString();
    TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape()));
  }

  // All args and results have the same dimensions, so pick an arbitrary one.
  const Shape& arg_shape = input_args[0]->shape();
  const Shape& out_shape = inferred_return_shape;
  bool is_tuple = out_shape.IsTuple();
  const Shape& output_shape = inferred_return_shape.IsTuple()
                                  ? inferred_return_shape.tuple_shapes(0)
                                  : inferred_return_shape;

  absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());

  // All increments are set to 0.
  std::vector<int64> arg_dim_steps(arg_dimensions.size());

  // All counts are set to 0.
  std::vector<int64> arg_dim_counts(arg_dimensions.size());

  // Set steps and counts for reduced dimensions.
  // This avoids iterating over non-reduced dimensions, as their step
  // and count are set to zero.
  for (const int64_t dim : dimensions_to_reduce) {
    arg_dim_steps[dim] = 1;
    arg_dim_counts[dim] = arg_dimensions[dim];
  }
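  // Example: reducing a [4, 5, 6] argument over dimension {1} gives
  // arg_dim_steps = {0, 1, 0} and arg_dim_counts = {0, 5, 0}, so the inner
  // loop walks only the 5 entries along dimension 1 for each output element.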

  // Map each dimension in the result to a dimension in arg that isn't
  // being reduced.
  std::vector<int64> result_to_arg_index;
  for (int64_t i = 0; i < arg_dimensions.size(); ++i) {
    if (arg_dim_steps[i] == 0) {
      result_to_arg_index.push_back(i);
    }
  }

  HloEvaluator embedded_evaluator(max_loop_iterations_);
  absl::InlinedVector<Literal, 1> results(num_args);
  for (int64_t i = 0; i < num_args; ++i) {
    results[i] = Literal(is_tuple ? out_shape.tuple_shapes(i) : out_shape);
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      output_shape, [&](absl::Span<const int64> output_index) {
        return GenerateReduceOutputElement(
            is_tuple, output_index, init_values, input_args,
            absl::Span<Literal>(results), function, &embedded_evaluator,
            arg_dim_steps, arg_dim_counts, result_to_arg_index);
      }));

  if (is_tuple) {
    Literal tuple_result(inferred_return_shape);
    for (int64_t i = 0; i < num_args; ++i) {
      TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
    }
    evaluated_[reduce] = std::move(tuple_result);
  } else {
    CHECK_EQ(results.size(), 1);
    evaluated_[reduce] = std::move(results[0]);
  }
  if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) {
    TF_ASSIGN_OR_RETURN(evaluated_[reduce],
                        evaluated_[reduce].ConvertToShape(reduce->shape()));
  }
  return Status::OK();
}

Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
  // Here we delegate the handling to the typed visitor class, instantiated by
  // using the type of the first input of ReduceWindow. The support for the
  // variadic case inside the typed visitor does not depend on the template
  // parameter, so it does not matter which type is used to instantiate it
  // here. We choose not to move the ReduceWindow implementation from the
  // typed visitor to here because it needs to reuse the IterateThroughWindow
  // method, which is defined and only available inside the typed visitor.
  if (hlo->shape().IsTuple()) {
    return hlo->Visit(
        typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get());
  } else {
    return DefaultAction(hlo);
  }
}

Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
  if (!custom_call_handler_) {
    // No handler is registered; this means custom-calls are not allowed.
    return DefaultAction(custom_call);
  }

  // Evaluate input operands so the handler has access to the operand data.
  std::vector<const Literal*> operands;
  operands.reserve(custom_call->operand_count());
  for (const HloInstruction* operand : custom_call->operands()) {
    operands.push_back(&GetEvaluatedLiteralFor(operand));
  }

  // Synchronously issue the handler to populate the instruction output literal.
  TF_ASSIGN_OR_RETURN(
      auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));

  evaluated_[custom_call] = std::move(output);
  return Status::OK();
}

Status HloEvaluator::Preprocess(HloInstruction* hlo) {
  VLOG(2) << "About to visit HLO: " << hlo->ToString();
  return ShapeUtil::ValidateShape(hlo->shape());
}

Status HloEvaluator::Postprocess(HloInstruction* hlo) {
  VLOG(2) << "Finished visiting " << hlo->ToString()
          << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
  // For convenience, the literal may have been produced with a different
  // layout; re-layout it as indicated by the HLO instruction.
  if (!Layout::Equal().MinorToMajorOnly()(
          GetEvaluatedLiteralFor(hlo).shape().layout(),
          hlo->shape().layout())) {
    evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
  }
  return Status::OK();
}

namespace {
template <typename T>
std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
    const Array2D<T>& lhs, const Array2D<T>& rhs,
    const std::function<void(const void* run_options_ptr, T* out, T* lhs,
                             T* rhs, int64_t m, int64_t n, int64_t k,
                             int32_t transpose_lhs, int32_t transpose_rhs)>&
        impl_fn) {
  CHECK_EQ(lhs.width(), rhs.height());
  int m = lhs.height();
  int n = rhs.width();
  int k = lhs.width();
  auto result = absl::make_unique<Array2D<T>>(m, n);
  // Because Eigen is a header-oriented library, make sure that the Eigen code
  // is the same as the code used by the CPU backend (otherwise the linker will
  // randomly pick *some* definition).
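  // Note that the operands and the m/n extents are passed swapped (rhs before
  // lhs, n before m), presumably because the single-threaded Eigen kernels
  // work on column-major buffers: computing rhs^T * lhs^T in column-major is
  // equivalent to computing lhs * rhs in row-major.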
  impl_fn(
      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
      k,
      /*transpose_lhs=*/0,
      /*transpose_rhs=*/0);
  return result;
}
}  // namespace

std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
    const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
  return MatmulArray2DImpl<Eigen::half>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
}

std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
    const Array2D<float>& lhs, const Array2D<float>& rhs) {
  return MatmulArray2DImpl<float>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
}

std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
    const Array2D<double>& lhs, const Array2D<double>& rhs) {
  return MatmulArray2DImpl<double>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
}

std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
    const Array2D<std::complex<float>>& lhs,
    const Array2D<std::complex<float>>& rhs) {
  return MatmulArray2DImpl<std::complex<float>>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
}

std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
    const Array2D<std::complex<double>>& lhs,
    const Array2D<std::complex<double>>& rhs) {
  return MatmulArray2DImpl<std::complex<double>>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
}

std::unique_ptr<Array2D<int32>> HloEvaluator::MatmulArray2D(
    const Array2D<int32>& lhs, const Array2D<int32>& rhs) {
  return MatmulArray2DImpl<int32>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulS32);
}

}  // namespace xla