1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <complex>
20 #include <cstdint>
21 #include <cstdlib>
22 #include <functional>
23 #include <iterator>
24 #include <memory>
25 #include <optional>
26 #include <string>
27 #include <type_traits>
28 #include <utility>
29 #include <vector>
30 
31 #include "absl/algorithm/container.h"
32 #include "absl/base/internal/endian.h"
33 #include "absl/cleanup/cleanup.h"
34 #include "absl/container/inlined_vector.h"
35 #include "absl/strings/match.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "tensorflow/compiler/xla/index_util.h"
39 #include "tensorflow/compiler/xla/layout_util.h"
40 #include "tensorflow/compiler/xla/literal.h"
41 #include "tensorflow/compiler/xla/literal_util.h"
42 #include "tensorflow/compiler/xla/map_util.h"
43 #include "tensorflow/compiler/xla/primitive_util.h"
44 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
45 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
46 #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
47 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_query.h"
50 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
51 #include "tensorflow/compiler/xla/service/shape_inference.h"
52 #include "tensorflow/compiler/xla/shape_util.h"
53 #include "tensorflow/compiler/xla/status_macros.h"
54 #include "tensorflow/compiler/xla/statusor.h"
55 #include "tensorflow/compiler/xla/types.h"
56 #include "tensorflow/compiler/xla/util.h"
57 #include "tensorflow/compiler/xla/window_util.h"
58 #include "tensorflow/core/lib/core/bitmap.h"
59 #include "tensorflow/core/lib/core/errors.h"
60 #include "tensorflow/core/lib/core/status.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/logging.h"
63 #include "tensorflow/core/platform/protobuf.h"
64 #include "tensorflow/core/platform/status.h"
65 #include "tensorflow/core/platform/statusor.h"
66 #include "tensorflow/core/platform/types.h"
67 #include "tensorflow/stream_executor/lib/statusor.h"
68 
69 namespace xla {
70 
71 namespace {
72 
73 template <typename OperandT>
74 StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
75                           LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
76   std::function<bool(OperandT, OperandT)> compare_op;
77   switch (direction) {
78     case ComparisonDirection::kEq:
79       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
80         return lhs_el == rhs_el;
81       };
82       break;
83     case ComparisonDirection::kNe:
84       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
85         return lhs_el != rhs_el;
86       };
87       break;
88     case ComparisonDirection::kGe:
89       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
90         return lhs_el >= rhs_el;
91       };
92       break;
93     case ComparisonDirection::kGt:
94       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
95         return lhs_el > rhs_el;
96       };
97       break;
98     case ComparisonDirection::kLe:
99       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
100         return lhs_el <= rhs_el;
101       };
102       break;
103     case ComparisonDirection::kLt:
104       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
105         return lhs_el < rhs_el;
106       };
107       break;
108   }
109 
110   Literal result(shape);
111   TF_RETURN_IF_ERROR(
112       result.Populate<bool>([&](absl::Span<const int64_t> multi_index) {
113         return compare_op(lhs_literal.Get<OperandT>(multi_index),
114                           rhs_literal.Get<OperandT>(multi_index));
115       }));
116 
117   return std::move(result);
118 }
119 
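// Usage sketch for the helper above (illustrative values only; assumes the
// XLA LiteralUtil/ShapeUtil APIs shown):
//
//   Literal lhs = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
//   Literal rhs = LiteralUtil::CreateR1<int32_t>({3, 2, 1});
//   StatusOr<Literal> result = Compare<int32_t>(
//       ShapeUtil::MakeShape(PRED, {3}), ComparisonDirection::kLt, lhs, rhs);
//   // On success, result holds pred[3] {true, false, false}.
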
120 template <>
121 StatusOr<Literal> Compare<complex64>(const Shape& shape,
122                                      ComparisonDirection direction,
123                                      LiteralSlice lhs_literal,
124                                      LiteralSlice rhs_literal) {
125   std::function<bool(complex64, complex64)> compare_op;
126   switch (direction) {
127     case ComparisonDirection::kEq:
128       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
129         return lhs_el == rhs_el;
130       };
131       break;
132     case ComparisonDirection::kNe:
133       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
134         return lhs_el != rhs_el;
135       };
136       break;
137     default:
138       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
139                  << ComparisonDirectionToString(direction);
140   }
141 
142   Literal result(shape);
143   TF_RETURN_IF_ERROR(
144       result.Populate<bool>([&](absl::Span<const int64_t> multi_index) {
145         return compare_op(lhs_literal.Get<complex64>(multi_index),
146                           rhs_literal.Get<complex64>(multi_index));
147       }));
148 
149   return std::move(result);
150 }
151 
152 template <>
153 StatusOr<Literal> Compare<complex128>(const Shape& shape,
154                                       ComparisonDirection direction,
155                                       LiteralSlice lhs_literal,
156                                       LiteralSlice rhs_literal) {
157   std::function<bool(complex128, complex128)> compare_op;
158   switch (direction) {
159     case ComparisonDirection::kEq:
160       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
161         return lhs_el == rhs_el;
162       };
163       break;
164     case ComparisonDirection::kNe:
165       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
166         return lhs_el != rhs_el;
167       };
168       break;
169     default:
170       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
171                  << ComparisonDirectionToString(direction);
172   }
173 
174   Literal result(shape);
175   TF_RETURN_IF_ERROR(
176       result.Populate<bool>([&](absl::Span<const int64_t> multi_index) {
177         return compare_op(lhs_literal.Get<complex128>(multi_index),
178                           rhs_literal.Get<complex128>(multi_index));
179       }));
180 
181   return std::move(result);
182 }
183 
184 // Represents an index into the while argument tuple and/or a value.
185 // At least one of param_index and value has a value; both of them could have
186 // a value.
187 struct ParamIndexAndValue {
188   std::optional<int64_t> param_index;
189   std::optional<int64_t> value;
190 
191   bool IsValid() const { return param_index.has_value() || value.has_value(); }
192 };
193 
194 // Represents the while loop condition comparison.
195 // We assume comparison is of the form: lhs comp rhs.
196 struct WhileCondComparison {
197   ComparisonDirection comparison_direction;
198   ParamIndexAndValue lhs;
199   ParamIndexAndValue rhs;
200 };
201 
202 // Represents the parsed while loop condition. The loop induction variable may
203 // either be used in a comparison or returned directly, i.e., NoOp. In the case
204 // of NoOp, it contains the parameter index and initial value of the loop
205 // induction variable.
206 using WhileCondComparisonOrNoOp =
207     std::variant<WhileCondComparison, ParamIndexAndValue>;
208 
209 // Finds the while loop condition comparison by matching the loop condition root
210 // with known patterns.
211 std::optional<WhileCondComparisonOrNoOp> PatternMatchLoopCondComparison(
212     HloInstruction* loop_cond_root) {
213   // Base pattern #1: gte-0 comp gte-1
214   if (Match(loop_cond_root,
215             match::Compare()
216                 .WithOperand(0, match::GetTupleElement().WithOperand(
217                                     0, match::Parameter().WithParameterNum(0)))
218                 .WithOperand(1,
219                              match::GetTupleElement().WithOperand(
220                                  0, match::Parameter().WithParameterNum(0))))) {
221     return WhileCondComparison{
222         loop_cond_root->comparison_direction(),
223         {/*param_index=*/loop_cond_root->operand(0)->tuple_index()},
224         {/*param_index=*/loop_cond_root->operand(1)->tuple_index()}};
225   }
226   // Base pattern #2: constant comp gte
227   if (Match(loop_cond_root,
228             match::Compare()
229                 .WithOperand(0, match::Constant())
230                 .WithOperand(1,
231                              match::GetTupleElement().WithOperand(
232                                  0, match::Parameter().WithParameterNum(0))))) {
233     std::optional<int64_t> lhs_value =
234         loop_cond_root->operand(0)->literal().GetFirstInteger();
235     if (!lhs_value.has_value()) {
236       return std::nullopt;
237     }
238     return WhileCondComparison{
239         loop_cond_root->comparison_direction(),
240         {/*param_index=*/std::nullopt, /*value=*/*lhs_value},
241         {/*param_index=*/loop_cond_root->operand(1)->tuple_index()}};
242   }
243   // Base pattern #3: gte comp constant
244   if (Match(loop_cond_root,
245             match::Compare()
246                 .WithOperand(0, match::GetTupleElement().WithOperand(
247                                     0, match::Parameter().WithParameterNum(0)))
248                 .WithOperand(1, match::Constant()))) {
249     std::optional<int64_t> rhs_value =
250         loop_cond_root->operand(1)->literal().GetFirstInteger();
251     if (!rhs_value.has_value()) {
252       return std::nullopt;
253     }
254     return WhileCondComparison{
255         loop_cond_root->comparison_direction(),
256         {/*param_index=*/loop_cond_root->operand(0)->tuple_index(),
257          /*value=*/std::nullopt},
258         {/*param_index=*/std::nullopt, /*value=*/*rhs_value},
259     };
260   }
261   // Base pattern #4: gte is a boolean scalar and is returned immediately.
262   if (Match(loop_cond_root, match::GetTupleElement().WithOperand(
263                                 0, match::Parameter().WithParameterNum(0)))) {
264     if (loop_cond_root->shape().element_type() != PrimitiveType::PRED ||
265         loop_cond_root->shape().rank() != 0) {
266       return std::nullopt;
267     }
268     return ParamIndexAndValue{{/*param_index=*/loop_cond_root->tuple_index()}};
269   }
270 
271   // Recursive pattern #1:
272   // loop_cond_root is a GetTupleElement whose operand is a call with a single
273   // parameter which takes the computation's single parameter.
274   // In this case, if the called computation's root is a tuple, we can recurse
275   // on that tuple's element as the new loop_cond_root.
276   if (Match(loop_cond_root,
277             match::GetTupleElement().WithOperand(
278                 0, match::Call().WithNumOperands(1).WithOperand(
279                        0, match::Parameter().WithParameterNum(0))))) {
280     HloInstruction* call_instruction = loop_cond_root->mutable_operand(0);
281     HloComputation* to_apply = call_instruction->to_apply();
282     HloInstruction* to_apply_root = to_apply->root_instruction();
283     if (Match(to_apply_root, match::Tuple())) {
284       return PatternMatchLoopCondComparison(
285           to_apply_root->mutable_operand(loop_cond_root->tuple_index()));
286     }
287   }
288   // Recursive pattern #2:
289   // loop_cond_root is a GetTupleElement whose operand is a tuple.
290   // We can recurse on the tuple's element as the new loop_cond_root.
291   if (Match(loop_cond_root,
292             match::GetTupleElement().WithOperand(0, match::Tuple()))) {
293     HloInstruction* new_cond_root =
294         loop_cond_root->mutable_operand(0)->mutable_operand(
295             loop_cond_root->tuple_index());
296     return PatternMatchLoopCondComparison(new_cond_root);
297   }
298   return std::nullopt;
299 }
300 
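// Illustrative HLO sketch of a while condition matched by base pattern #3
// above (names and values are hypothetical):
//
//   cond {
//     p = (s32[], f32[8]) parameter(0)
//     i = s32[] get-tuple-element(p), index=0
//     n = s32[] constant(10)
//     ROOT lt = pred[] compare(i, n), direction=LT
//   }
//
// For this computation the function returns WhileCondComparison{kLt,
// {/*param_index=*/0}, {/*value=*/10}}.
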
301 // Tries to parse the loop body to find how the induction variable is updated
302 // using pattern matching.
303 std::optional<int64_t> PatternMatchInductionVarUpdate(
304     HloInstruction* loop_body_root, int64_t tuple_index) {
305   // Pattern #1: induc_var = induc_var + constant
306   if (Match(loop_body_root,
307             match::Tuple().WithOperand(
308                 tuple_index,
309                 match::Add()
310                     .WithOperand(0, match::GetTupleElement()
311                                         .WithTupleIndex(tuple_index)
312                                         .WithOperand(0, match::Parameter()))
313                     .WithOperand(1, match::Constant())))) {
314     std::optional<int64_t> step_size = loop_body_root->operand(tuple_index)
315                                            ->operand(1)
316                                            ->literal()
317                                            .GetFirstInteger();
318     if (!step_size.has_value()) {
319       return std::nullopt;
320     }
321     return *step_size;
322   }
323   // Pattern #2: induc_var = constant + induc_var
324   if (Match(
325           loop_body_root,
326           match::Tuple().WithOperand(
327               tuple_index,
328               match::Add()
329                   .WithOperand(0, match::Constant())
330                   .WithOperand(1, match::GetTupleElement()
331                                       .WithTupleIndex(tuple_index)
332                                       .WithOperand(0, match::Parameter()))))) {
333     std::optional<int64_t> step_size = loop_body_root->operand(tuple_index)
334                                            ->operand(0)
335                                            ->literal()
336                                            .GetFirstInteger();
337     if (!step_size.has_value()) {
338       return std::nullopt;
339     }
340     return *step_size;
341   }
342 
343   // Pattern #3: induc_var = induc_var - constant
344   if (Match(loop_body_root,
345             match::Tuple().WithOperand(
346                 tuple_index,
347                 match::Subtract()
348                     .WithOperand(0, match::GetTupleElement()
349                                         .WithTupleIndex(tuple_index)
350                                         .WithOperand(0, match::Parameter()))
351                     .WithOperand(1, match::Constant())))) {
352     std::optional<int64_t> step_size = loop_body_root->operand(tuple_index)
353                                            ->operand(1)
354                                            ->literal()
355                                            .GetFirstInteger();
356     if (!step_size.has_value()) {
357       return std::nullopt;
358     }
359     return -*step_size;
360   }
361 
362   // Pattern #4: the induc_var is directly returned from the loop body with
363   // no changes.
364   if (Match(loop_body_root,
365             match::Tuple().WithOperand(
366                 tuple_index,
367                 match::GetTupleElement()
368                     .WithOperand(0, match::Parameter().WithParameterNum(0))
369                     .WithTupleIndex(tuple_index)))) {
370     return 0;
371   }
372   return std::nullopt;
373 }
374 
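// Illustrative HLO sketch of a loop body matched by pattern #1 above (names
// are hypothetical):
//
//   body {
//     p = (s32[], f32[8]) parameter(0)
//     i = s32[] get-tuple-element(p), index=0
//     c = s32[] constant(2)
//     next_i = s32[] add(i, c)
//     ...
//     ROOT t = (s32[], f32[8]) tuple(next_i, ...)
//   }
//
// PatternMatchInductionVarUpdate(t, /*tuple_index=*/0) returns 2; for the
// Subtract pattern (#3) the returned step size is negated.
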
375 std::optional<bool> PatternMatchLoopCondVarOverride(
376     HloInstruction* loop_body_root, int64_t tuple_index) {
377   if (Match(loop_body_root, match::Tuple()) &&
378       loop_body_root->operand_count() > tuple_index) {
379     HloInstruction* cond_var_override =
380         loop_body_root->mutable_operand(tuple_index);
381     HloEvaluator evaluator;
382     StatusOr<Literal> new_cond_var = evaluator.Evaluate(
383         cond_var_override, /*recursively_evaluate_nonconstant_operands=*/true);
384     if (new_cond_var.ok()) {
385       return new_cond_var->GetFirstElement<bool>();
386     }
387   }
388   return std::nullopt;
389 }
390 
391 // Represents a value that might or might not be determined statically.
392 struct DynamicOrStaticValue {
393   std::optional<int64_t> static_value;
394   bool is_dynamic() const { return !static_value.has_value(); }
395 };
396 
397 constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
398 
399 // Use this enum to represent the precise details of the error to enable
400 // special treatment.
401 enum class EvalErrorDetail : uint32_t {
402   // The evaluation result depends on dynamic values such as parameters and
403   // infeed. Therefore, the HLO's value cannot be statically evaluated.
404   kDynamicValueDependence = 0,
405 };
406 
407 Status MakeEvalErrorDueToParamOrInfeed(const HloInstruction& eval_instruction) {
408   Status error = tensorflow::errors::FailedPrecondition(
409       "Failed to evaluate instruction (", eval_instruction.name(),
410       ") since it depends on infeed or parameters to its parent computation (",
411       eval_instruction.parent()->name(), ").");
412   std::string error_payload;
413   error_payload.resize(sizeof(EvalErrorDetail));
414   absl::little_endian::Store32(
415       const_cast<char*>(error_payload.data()),
416       static_cast<uint32_t>(EvalErrorDetail::kDynamicValueDependence));
417   error.SetPayload(kEvalErrorDetailUrl, error_payload);
418   return error;
419 }
420 
421 std::optional<EvalErrorDetail> ParseEvalErrorDetail(const Status& error) {
422   auto error_detail = error.GetPayload(kEvalErrorDetailUrl);
423   if (!error_detail.has_value() || error_detail->empty()) {
424     return std::nullopt;
425   }
426   return static_cast<EvalErrorDetail>(
427       absl::little_endian::Load32(error_detail->Flatten().data()));
428 }
429 
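// Round-trip sketch for the two helpers above (illustrative only): an error
// built by MakeEvalErrorDueToParamOrInfeed carries a payload that
// ParseEvalErrorDetail can recover.
//
//   Status error = MakeEvalErrorDueToParamOrInfeed(*instruction);
//   std::optional<EvalErrorDetail> detail = ParseEvalErrorDetail(error);
//   // detail == EvalErrorDetail::kDynamicValueDependence
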
430 // A convenience wrapper to compute the while loop's argument's init value at
431 // the given tuple_index. If the init value depends on parameters to the
432 // while loop's parent computation or infeed, we consider the init value
433 // dynamic.
434 std::optional<DynamicOrStaticValue> EvaluateWhileLoopParamInitValue(
435     HloInstruction* param_instruction, int64_t tuple_index) {
436   if (param_instruction->opcode() != HloOpcode::kTuple) {
437     return std::nullopt;
438   }
439   HloInstruction* element_instruction =
440       param_instruction->mutable_operand(tuple_index);
441   HloEvaluator evaluator;
442   StatusOr<Literal> value = evaluator.Evaluate(
443       element_instruction, /*recursively_evaluate_nonconstant_operands=*/true);
444   if (value.ok()) {
445     if (element_instruction->shape().element_type() == PrimitiveType::PRED) {
446       return DynamicOrStaticValue{
447           static_cast<int64_t>(value->GetFirstElement<bool>())};
448     } else {
449       return DynamicOrStaticValue{value->GetFirstInteger()};
450     }
451   } else {
452     std::optional<EvalErrorDetail> eval_error_detail =
453         ParseEvalErrorDetail(value.status());
454     if (eval_error_detail.has_value() &&
455         *eval_error_detail == EvalErrorDetail::kDynamicValueDependence) {
456       return DynamicOrStaticValue{std::nullopt};
457     }
458   }
459   return std::nullopt;
460 }
461 
462 }  // namespace
463 
464 std::optional<ParsedWhileLoop> PatternMatchParseWhileLoop(
465     HloInstruction* while_op) {
466   HloComputation* while_cond = while_op->while_condition();
467   HloComputation* while_body = while_op->while_body();
468   HloInstruction* while_operand = while_op->mutable_operand(0);
469   // Try to parse the loop condition comparison.
470   std::optional<WhileCondComparisonOrNoOp> loop_comparison_or_noop =
471       PatternMatchLoopCondComparison(while_cond->root_instruction());
472   if (!loop_comparison_or_noop.has_value()) {
473     return std::nullopt;
474   }
475   if (loop_comparison_or_noop->index() == 1) {
476     ParamIndexAndValue& parameter_index_and_value =
477         std::get<ParamIndexAndValue>(*loop_comparison_or_noop);
478     CHECK(parameter_index_and_value.param_index.has_value());
479     int64_t loop_cond_var_index = *parameter_index_and_value.param_index;
480     std::optional<DynamicOrStaticValue> noop_value =
481         EvaluateWhileLoopParamInitValue(while_operand, loop_cond_var_index);
482 
483     if (noop_value.has_value()) {
484       if (noop_value->is_dynamic()) {
485         return kParsedDynamicWhileLoop;
486       } else if (*noop_value->static_value == 0) {
487         return ParsedWhileLoop{
488             ParsedStaticWhileLoop{/*trip_count=*/0,
489                                   /*induction_var_index=*/loop_cond_var_index,
490                                   /*induction_var_init_value=*/0,
491                                   /*step_size=*/0,
492                                   /*loop_bound=*/0}};
493       }
494       std::optional<bool> updated_loop_cond_var =
495           PatternMatchLoopCondVarOverride(while_body->root_instruction(),
496                                           loop_cond_var_index);
497       if (updated_loop_cond_var.has_value()) {
498         if (!*updated_loop_cond_var) {
499           return ParsedWhileLoop{
500               ParsedStaticWhileLoop{/*trip_count=*/1,
501                                     /*induction_var_index=*/loop_cond_var_index,
502                                     /*induction_var_init_value=*/0,
503                                     /*step_size=*/1,
504                                     /*loop_bound=*/1}};
505         } else {
506           // This is an infinite loop and we set trip_count to -1.
507           return ParsedWhileLoop{
508               ParsedStaticWhileLoop{/*trip_count=*/-1,
509                                     /*induction_var_index=*/loop_cond_var_index,
510                                     /*induction_var_init_value=*/0,
511                                     /*step_size=*/0,
512                                     /*loop_bound=*/1}};
513         }
514       }
515     }
516     return std::nullopt;
517   }
518   CHECK_EQ(loop_comparison_or_noop->index(), 0);
519   WhileCondComparison loop_comparison =
520       std::get<WhileCondComparison>(*loop_comparison_or_noop);
521   CHECK(loop_comparison.lhs.IsValid() && loop_comparison.rhs.IsValid());
522 
523   // If either side of the while loop condition comparison takes its init value
524   // from a parameter of the while loop's parent computation, the loop is dynamic.
525   if (while_operand->opcode() == HloOpcode::kParameter) {
526     if (loop_comparison.lhs.param_index.has_value() ||
527         loop_comparison.rhs.param_index.has_value()) {
528       return kParsedDynamicWhileLoop;
529     }
530   }
531 
532   // We can't handle the case when the while loop argument is not a Tuple
533   // instruction.
534   if (while_operand->opcode() != HloOpcode::kTuple) {
535     return std::nullopt;
536   }
537 
538   // If loop cond comparison LHS does not have a value defined inside the loop
539   // cond computation, try to evaluate its init value inside the while loop's
540   // parent computation.
541   if (!loop_comparison.lhs.value.has_value()) {
542     std::optional<DynamicOrStaticValue> lhs_init_value =
543         EvaluateWhileLoopParamInitValue(while_operand,
544                                         *loop_comparison.lhs.param_index);
545     if (lhs_init_value.has_value()) {
546       if (lhs_init_value->is_dynamic()) {
547         return kParsedDynamicWhileLoop;
548       } else {
549         loop_comparison.lhs.value = *(lhs_init_value->static_value);
550       }
551     } else {
552       return std::nullopt;
553     }
554   }
555 
556   // If loop cond comparison RHS does not have a value defined inside the loop
557   // cond computation, try to evaluate its init value inside the while loop's
558   // parent computation.
559   if (!loop_comparison.rhs.value.has_value()) {
560     std::optional<DynamicOrStaticValue> rhs_init_value =
561         EvaluateWhileLoopParamInitValue(while_operand,
562                                         *loop_comparison.rhs.param_index);
563     if (rhs_init_value.has_value()) {
564       if (rhs_init_value->is_dynamic()) {
565         return kParsedDynamicWhileLoop;
566       } else {
567         loop_comparison.rhs.value = *(rhs_init_value->static_value);
568       }
569     } else {
570       return std::nullopt;
571     }
572   }
573 
574   // At this point we have either evaluated the init values of both LHS and
575   // RHS successfully, or we have already returned (dynamic loop or failure).
576   CHECK(loop_comparison.lhs.value.has_value());
577   CHECK(loop_comparison.rhs.value.has_value());
578 
579   if (loop_comparison.lhs.param_index.has_value()) {
580     VLOG(3) << __func__ << " lhs index: " << *loop_comparison.lhs.param_index;
581   }
582 
583   VLOG(3) << __func__ << " lhs bound: " << *loop_comparison.lhs.value;
584 
585   if (loop_comparison.rhs.param_index.has_value()) {
586     VLOG(3) << __func__ << " rhs index: " << *loop_comparison.rhs.param_index;
587   }
588 
589   VLOG(3) << __func__ << " rhs bound: " << *loop_comparison.rhs.value;
590 
591   // Check whether LHS is the loop induction var.
592   std::optional<int64_t> lhs_induction_var_update;
593   if (loop_comparison.lhs.param_index.has_value()) {
594     lhs_induction_var_update = PatternMatchInductionVarUpdate(
595         while_body->root_instruction(), *loop_comparison.lhs.param_index);
596   }
597 
598   // Check whether RHS is the loop induction var.
599   std::optional<int64_t> rhs_induction_var_update;
600   if (loop_comparison.rhs.param_index.has_value()) {
601     rhs_induction_var_update = PatternMatchInductionVarUpdate(
602         while_body->root_instruction(), *loop_comparison.rhs.param_index);
603   }
604 
605   // Lhs is the induction variable.
606   if (lhs_induction_var_update.has_value()) {
607     // We cannot handle the case when both LHS and RHS are updated inside
608     // the loop body.
609     if (rhs_induction_var_update.has_value() &&
610         *rhs_induction_var_update != 0) {
611       return std::nullopt;
612     }
613     if (*lhs_induction_var_update > 0 &&
614         (loop_comparison.comparison_direction == Comparison::Direction::kLt ||
615          loop_comparison.comparison_direction == Comparison::Direction::kLe)) {
616       int64_t trip_count =
617           (*loop_comparison.rhs.value - *loop_comparison.lhs.value - 1) /
618               *lhs_induction_var_update +
619           1;
620       // Additional logic to deal with Equal comparison.
621       if (loop_comparison.comparison_direction == Comparison::Direction::kLe &&
622           (*loop_comparison.rhs.value - *loop_comparison.lhs.value) %
623                   *lhs_induction_var_update ==
624               0) {
625         trip_count += 1;
626       }
627       return ParsedWhileLoop{ParsedStaticWhileLoop{
628           /*trip_count=*/trip_count,
629           /*induction_var_index=*/*loop_comparison.lhs.param_index,
630           /*induction_var_init_value=*/*loop_comparison.lhs.value,
631           /*step_size=*/*lhs_induction_var_update,
632           /*loop_bound=*/*loop_comparison.rhs.value}};
633     } else if (*lhs_induction_var_update < 0 &&
634                (loop_comparison.comparison_direction ==
635                     Comparison::Direction::kGt ||
636                 loop_comparison.comparison_direction ==
637                     Comparison::Direction::kGe)) {
638       int64_t trip_count =
639           (*loop_comparison.lhs.value - *loop_comparison.rhs.value - 1) /
640               -*lhs_induction_var_update +
641           1;
642       if (loop_comparison.comparison_direction == Comparison::Direction::kGe &&
643           (*loop_comparison.lhs.value - *loop_comparison.rhs.value) %
644                   *lhs_induction_var_update ==
645               0) {
646         trip_count += 1;
647       }
648       return ParsedWhileLoop{ParsedStaticWhileLoop{
649           /*trip_count=*/trip_count,
650           /*induction_var_index=*/*(loop_comparison.lhs.param_index),
651           /*induction_var_init_value=*/*(loop_comparison.lhs.value),
652           /*step_size=*/-*lhs_induction_var_update,
653           /*loop_bound=*/*(loop_comparison.rhs.value)}};
654     }
655     return std::nullopt;
656   }
657   // Rhs is the induction variable.
658   if (rhs_induction_var_update.has_value()) {
659     // We cannot handle the case when both LHS and RHS are updated inside
660     // the loop body.
661     if (lhs_induction_var_update.has_value() &&
662         *lhs_induction_var_update != 0) {
663       return std::nullopt;
664     }
665     if (*rhs_induction_var_update > 0 &&
666         (loop_comparison.comparison_direction == Comparison::Direction::kGt ||
667          loop_comparison.comparison_direction == Comparison::Direction::kGe)) {
668       int64_t trip_count =
669           (*loop_comparison.lhs.value - *loop_comparison.rhs.value - 1) /
670               *rhs_induction_var_update +
671           1;
672       if (loop_comparison.comparison_direction == Comparison::Direction::kGe &&
673           (*loop_comparison.lhs.value - *loop_comparison.rhs.value) %
674                   *rhs_induction_var_update ==
675               0) {
676         trip_count += 1;
677       }
678       return ParsedWhileLoop{ParsedStaticWhileLoop{
679           /*trip_count=*/trip_count,
680           /*induction_var_index=*/*(loop_comparison.rhs.param_index),
681           /*induction_var_init_value=*/*(loop_comparison.rhs.value),
682           /*step_size=*/*rhs_induction_var_update,
683           /*loop_bound=*/*(loop_comparison.lhs.value)}};
684     } else if (*rhs_induction_var_update < 0 &&
685                (loop_comparison.comparison_direction ==
686                     Comparison::Direction::kLt ||
687                 loop_comparison.comparison_direction ==
688                     Comparison::Direction::kLe)) {
689       int64_t trip_count =
690           (*loop_comparison.rhs.value - *loop_comparison.lhs.value - 1) /
691               -*rhs_induction_var_update +
692           1;
693       if (loop_comparison.comparison_direction == Comparison::Direction::kLe &&
694           (*loop_comparison.rhs.value - *loop_comparison.lhs.value) %
695                   *rhs_induction_var_update ==
696               0) {
697         trip_count += 1;
698       }
699       return ParsedWhileLoop{ParsedStaticWhileLoop{
700           /*trip_count=*/trip_count,
701           /*induction_var_index=*/*(loop_comparison.rhs.param_index),
702           /*induction_var_init_value=*/*(loop_comparison.rhs.value),
703           /*step_size=*/-*rhs_induction_var_update,
704           /*loop_bound=*/*(loop_comparison.lhs.value)}};
705     }
706     return std::nullopt;
707   }
708   return std::nullopt;
709 }
710 
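// Worked example for the trip-count arithmetic above (illustrative values):
// for a loop equivalent to `for (i = 0; i < 10; i += 3)`, we have
// lhs.value = 0, rhs.value = 10, step = 3, direction = kLt, so
//   trip_count = (10 - 0 - 1) / 3 + 1 = 4   // i takes 0, 3, 6, 9.
// With direction kLe, e.g. `i <= 9` stepping by 3, the base formula gives
// (9 - 0 - 1) / 3 + 1 = 3, and the (rhs - lhs) % step == 0 adjustment adds
// the final iteration at i == 9, for a total of 4.
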
711 // Note that a type that is unsupported by the typed visitor is not
712 // necessarily unsupported by the non-typed HloEvaluator (the parent
713 // evaluator); its type-agnostic handlers may still accept it. For example,
714 // HandleGetTupleElement in the parent type-agnostic evaluator can accept the
715 // Tuple primitive type, whereas HloEvaluatorTypedVisitor cannot.
716 HloEvaluator::HloEvaluator(int64_t max_loop_iterations)
717     : max_loop_iterations_(max_loop_iterations) {
718   typed_visitors_[PRED] =
719       std::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
720   typed_visitors_[U8] =
721       std::make_unique<HloEvaluatorTypedVisitor<uint8_t>>(this);
722   typed_visitors_[U16] =
723       std::make_unique<HloEvaluatorTypedVisitor<uint16_t>>(this);
724   typed_visitors_[U32] =
725       std::make_unique<HloEvaluatorTypedVisitor<uint32_t>>(this);
726   typed_visitors_[U64] =
727       std::make_unique<HloEvaluatorTypedVisitor<uint64_t>>(this);
728   typed_visitors_[S8] =
729       std::make_unique<HloEvaluatorTypedVisitor<int8_t>>(this);
730   typed_visitors_[S16] =
731       std::make_unique<HloEvaluatorTypedVisitor<int16_t>>(this);
732   typed_visitors_[S32] =
733       std::make_unique<HloEvaluatorTypedVisitor<int32_t>>(this);
734   typed_visitors_[S64] =
735       std::make_unique<HloEvaluatorTypedVisitor<int64_t>>(this);
736   typed_visitors_[F16] =
737       std::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
738   typed_visitors_[F32] =
739       std::make_unique<HloEvaluatorTypedVisitor<float>>(this);
740   typed_visitors_[F64] =
741       std::make_unique<HloEvaluatorTypedVisitor<double>>(this);
742   typed_visitors_[C64] =
743       std::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
744   typed_visitors_[C128] =
745       std::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
746 
747   // Most of the evaluator computations we use don't support BF16 (e.g.,
748   // std::ceil, std::tanh). To make evaluator work with BF16, we set all
749   // elementwise computations to be done in F32 and do BF16<->F32 conversion
750   // around the input and the output of the computations.
751   typed_visitors_[BF16] =
752       std::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
753 
754   typed_visitors_[TUPLE] =
755       std::make_unique<FunctionVisitor>([](HloInstruction*) {
756         return Unimplemented(
757             "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
758       });
759   typed_visitors_[OPAQUE_TYPE] =
760       std::make_unique<FunctionVisitor>([](HloInstruction*) {
761         return Unimplemented(
762             "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE.");
763       });
764   typed_visitors_[TOKEN] =
765       std::make_unique<FunctionVisitor>([](HloInstruction*) {
766         return Unimplemented(
767             "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
768       });
769 }
770 
771 StatusOr<Literal> HloEvaluator::Evaluate(
772     const HloComputation& computation,
773     absl::Span<const Literal* const> arg_literals) {
774   CHECK(computation.parent() != nullptr);
775   XLA_VLOG_LINES(
776       2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
777 
778   if (arg_literals.size() != computation.num_parameters()) {
779     return InvalidArgument(
780         "Expected %d argument%s, but got %d.", computation.num_parameters(),
781         computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
782   }
783   for (int64_t i = 0; i < arg_literals.size(); ++i) {
784     const auto& computation_shape =
785         computation.parameter_instruction(i)->shape();
786     const auto& arg_shape = arg_literals[i]->shape();
787     if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
788                                                    arg_shape)) {
789       return InvalidArgument(
790           "Shape mismatch at parameter %d. Computation expected %s, but arg "
791           "was %s.",
792           i, ShapeUtil::HumanStringWithLayout(computation_shape),
793           ShapeUtil::HumanStringWithLayout(arg_shape));
794     }
795   }
796 
797   evaluated_.clear();
798   arg_literals_.clear();
799   for (const auto& literal_ptr : arg_literals) {
800     arg_literals_.push_back(&*literal_ptr);
801   }
802 
803   // Re-seed RNG, either from the configuration's seed or a monotonic
804   // per-evaluator seed (which prevents two evaluators from returning the same
805   // random sequence).
806   if (computation.parent()->config().seed()) {
807     seed_ = computation.parent()->config().seed();
808   } else {
809     // Start global_seed at a (true) random value.
810     static std::atomic<uint64_t> global_seed{std::random_device()()};
811     seed_ = global_seed.fetch_add(1);
812   }
813   engine_.seed(seed_);
814 
815   TF_RETURN_IF_ERROR(computation.Accept(this));
816   const Literal& result =
817       GetEvaluatedLiteralFor(computation.root_instruction());
818   if (VLOG_IS_ON(100)) {
819     for (const HloInstruction* instr : computation.instructions()) {
820       VLOG(100) << instr->name() << " = " << GetEvaluatedLiteralFor(instr);
821     }
822   }
823   if (!result.IsKnown()) {
824     return MakeEvalErrorDueToParamOrInfeed(*computation.root_instruction());
825   }
826   return result.Clone();
827 }
828 
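// Usage sketch for the computation overload above (hedged; `module` and the
// argument value are hypothetical):
//
//   HloEvaluator evaluator;
//   Literal arg = LiteralUtil::CreateR0<float>(2.0f);
//   StatusOr<Literal> result =
//       evaluator.Evaluate(*module->entry_computation(), {&arg});
//   // On success, result is a clone of the root instruction's literal.
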
829 StatusOr<Literal> HloEvaluator::Evaluate(
830     HloInstruction* instruction,
831     bool recursively_evaluate_nonconstant_operands) {
832   arg_literals_.clear();
833   evaluated_.clear();
834   auto enable_partial_evaluation_cleanup =
835       absl::MakeCleanup([this] { enable_partial_evaluation_ = false; });
836   enable_partial_evaluation_ = recursively_evaluate_nonconstant_operands;
837   TF_RETURN_IF_ERROR(
838       EvaluateInternal(instruction, /*shape_index=*/{},
839                        recursively_evaluate_nonconstant_operands));
840   const Literal& result = GetEvaluatedLiteralFor(instruction);
841   if (!result.IsKnown()) {
842     return MakeEvalErrorDueToParamOrInfeed(*instruction);
843   }
844   return result.Clone();
845 }
846 
847 bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result,
848                                bool recursively_evaluate_nonconstant_operands) {
849   CHECK(result != nullptr);
850   auto result_or =
851       Evaluate(instruction, recursively_evaluate_nonconstant_operands);
852   if (!result_or.ok()) {
853     VLOG(1) << "TryEvaluate failed: " << result_or.status();
854     return false;
855   }
856 
857   *result = std::move(result_or).value();
858   return true;
859 }
860 
861 StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
862     const HloInstruction* instruction,
863     const absl::flat_hash_map<const HloInstruction*, const Literal*>&
864         substitutions) {
865   std::vector<std::unique_ptr<HloInstruction>> owned_operands;
866   for (const HloInstruction* operand : instruction->operands()) {
867     auto it = substitutions.find(operand);
868     if (it == substitutions.end()) {
869       owned_operands.push_back(operand->Clone());
870     } else {
871       owned_operands.push_back(
872           HloInstruction::CreateConstant(it->second->Clone()));
873     }
874   }
875 
876   std::vector<HloInstruction*> operands;
877   operands.reserve(owned_operands.size());
878   for (auto& operand : owned_operands) {
879     operands.push_back(operand.get());
880   }
881 
882   std::unique_ptr<HloInstruction> cloned_instruction =
883       instruction->CloneWithNewOperands(instruction->shape(), operands);
884   auto result = Evaluate(cloned_instruction.get());
885 
886   return result;
887 }
888 
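// Usage sketch for EvaluateWithSubstitutions (hedged; `add`, `x`, and `y` are
// hypothetical instructions):
//
//   // Evaluate add = f32[] add(x, y) as if x were 1.0f and y were 2.0f.
//   Literal one = LiteralUtil::CreateR0<float>(1.0f);
//   Literal two = LiteralUtil::CreateR0<float>(2.0f);
//   StatusOr<Literal> sum =
//       evaluator.EvaluateWithSubstitutions(add, {{x, &one}, {y, &two}});
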
889 StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
890     HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
891   std::unique_ptr<HloInstruction> lhs_instr =
892       HloInstruction::CreateConstant(lhs.Clone());
893   std::unique_ptr<HloInstruction> rhs_instr =
894       HloInstruction::CreateConstant(rhs.Clone());
895 
896   std::unique_ptr<HloInstruction> cloned_instruction =
897       HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
898                                    rhs_instr.get());
899   auto result = Evaluate(cloned_instruction.get());
900 
901   return result;
902 }
903 
904 StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
905     HloOpcode opcode, const Literal& lhs, const Literal& rhs,
906     const Literal& ehs) {
907   std::unique_ptr<HloInstruction> lhs_instr =
908       HloInstruction::CreateConstant(lhs.Clone());
909   std::unique_ptr<HloInstruction> rhs_instr =
910       HloInstruction::CreateConstant(rhs.Clone());
911   std::unique_ptr<HloInstruction> ehs_instr =
912       HloInstruction::CreateConstant(ehs.Clone());
913   TF_ASSIGN_OR_RETURN(auto output_shape,
914                       ShapeInference::InferTernaryOpShape(
915                           opcode, lhs.shape(), rhs.shape(), ehs.shape()));
916   std::unique_ptr<HloInstruction> cloned_instruction =
917       HloInstruction::CreateTernary(output_shape, opcode, lhs_instr.get(),
918                                     rhs_instr.get(), ehs_instr.get());
919   return Evaluate(cloned_instruction.get());
920 }
921 
922 StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
923     ComparisonDirection direction, const Literal& lhs, const Literal& rhs) {
924   std::unique_ptr<HloInstruction> lhs_instr =
925       HloInstruction::CreateConstant(lhs.Clone());
926   std::unique_ptr<HloInstruction> rhs_instr =
927       HloInstruction::CreateConstant(rhs.Clone());
928 
929   std::unique_ptr<HloInstruction> cloned_instruction =
930       HloInstruction::CreateCompare(
931           ShapeUtil::ChangeElementType(lhs.shape(), PRED), lhs_instr.get(),
932           rhs_instr.get(), direction);
933   auto result = Evaluate(cloned_instruction.get());
934 
935   return result;
936 }
937 
938 StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
939     HloOpcode opcode, const Literal& operand) {
940   std::unique_ptr<HloInstruction> operand_instr =
941       HloInstruction::CreateConstant(operand.Clone());
942 
943   TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferUnaryOpShape(
944                                                 opcode, operand.shape()));
945   std::unique_ptr<HloInstruction> cloned_instruction =
946       HloInstruction::CreateUnary(inferred_shape, opcode, operand_instr.get());
947   auto result = Evaluate(cloned_instruction.get());
948 
949   return result;
950 }
951 
952 StatusOr<Literal> HloEvaluator::EvaluateDotOp(
953     const DotDimensionNumbers& dim_numbers,
954     const PrecisionConfig& precision_config, const Literal& lhs,
955     const Literal& rhs) {
956   std::unique_ptr<HloInstruction> lhs_instr =
957       HloInstruction::CreateConstant(lhs.Clone());
958   std::unique_ptr<HloInstruction> rhs_instr =
959       HloInstruction::CreateConstant(rhs.Clone());
960 
961   TF_ASSIGN_OR_RETURN(
962       Shape dot_shape,
963       ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers,
964                                       /*preferred_element_type=*/std::nullopt));
965 
966   std::unique_ptr<HloInstruction> cloned_instruction =
967       HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
968                                 dim_numbers, precision_config);
969   return Evaluate(cloned_instruction.get());
970 }
971 
972 Status HloEvaluator::EvaluateInternal(
973     HloInstruction* instruction, const ShapeIndex& shape_index,
974     bool recursively_evaluate_nonconstant_operands) {
975   // Don't need to evaluate this instruction again if it has already been
976   // evaluated.
977   if (IsAlreadyEvaluated(instruction, shape_index)) {
978     return OkStatus();
979   }
980 
981   if (!recursively_evaluate_nonconstant_operands) {
982     if (!hlo_query::AllOperandsAreConstants(*instruction)) {
983       return tensorflow::errors::FailedPrecondition(
984           "Not all operands are constants.");
985     }
986   } else {
987     if (instruction->opcode() == HloOpcode::kGetTupleElement) {
988       ShapeIndex new_shape_index = shape_index;
989       new_shape_index.push_front(instruction->tuple_index());
990       TF_RETURN_IF_ERROR(
991           EvaluateInternal(instruction->mutable_operand(0), new_shape_index,
992                            /*recursively_evaluate_nonconstant_operands=*/true));
993     } else if (instruction->opcode() == HloOpcode::kTuple &&
994                !shape_index.empty()) {
995       ShapeIndex new_shape_index = shape_index;
996       int64_t tuple_index = new_shape_index.front();
997       new_shape_index.pop_front();
998       TF_RETURN_IF_ERROR(EvaluateInternal(
999           instruction->mutable_operand(tuple_index), new_shape_index,
1000           /*recursively_evaluate_nonconstant_operands=*/true));
1001     } else {
1002       for (HloInstruction* operand : instruction->operands()) {
1003         TF_RETURN_IF_ERROR(EvaluateInternal(
1004             operand, /*shape_index=*/{},
1005             /*recursively_evaluate_nonconstant_operands=*/true));
1006         // Except for the cases handled above and the opcodes exempted below,
1007         // we do not support unknown operands; mark the result as unknown.
1008         if ((!GetEvaluatedLiteralFor(operand).IsKnown() &&
1009              instruction->opcode() != HloOpcode::kCopy &&
1010              instruction->opcode() != HloOpcode::kCopyStart &&
1011              instruction->opcode() != HloOpcode::kCopyDone &&
1012              instruction->opcode() != HloOpcode::kAsyncStart &&
1013              instruction->opcode() != HloOpcode::kAsyncUpdate &&
1014              instruction->opcode() != HloOpcode::kAsyncDone &&
1015              instruction->opcode() != HloOpcode::kWhile)) {
1016           evaluated_[instruction] =
1017               Literal::CreateFromShapeWithUnknownLeafArrays(
1018                   instruction->shape());
1019           return OkStatus();
1020         }
1021       }
1022     }
1023   }
1024   visitor_shape_index_ = shape_index;
1025   TF_RETURN_IF_ERROR(Preprocess(instruction));
1026   TF_RETURN_IF_ERROR(instruction->Visit(this));
1027   TF_RETURN_IF_ERROR(Postprocess(instruction));
1028   return OkStatus();
1029 }
1030 
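// Sketch of the shape-index bookkeeping above (illustrative): when partially
// evaluating gte = get-tuple-element(t, index=1) where t = tuple(a, b), the
// kGetTupleElement branch pushes 1 onto the shape index before recursing, and
// the kTuple branch pops it again, so only operand `b` is actually evaluated.
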
1031 Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
1032   const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
1033   Literal result(bitcast->shape());
1034   // Bitcast output is allowed to be smaller than the input if the backend-
1035   // specific buffer sizes for the input and output are the same. Since the HLO
1036   // evaluator doesn't have access to the backend-specific shape size function,
1037   // assume it's OK to bitcast if output <= input.
1038   TF_RET_CHECK(operand_literal.size_bytes() >= result.size_bytes());
1039   memcpy(result.untyped_data(), operand_literal.untyped_data(),
1040          result.size_bytes());
1041   evaluated_[bitcast] = std::move(result);
1042   return OkStatus();
1043 }
1044 
1045 Status HloEvaluator::HandleGetDimensionSize(
1046     HloInstruction* get_dimension_size) {
1047   HloInstruction* operand = get_dimension_size->mutable_operand(0);
1048   int64_t dim = get_dimension_size->dimension();
1049   if (dynamic_dimension_inference_ == nullptr) {
1050     return InvalidArgument(
1051         "Evaluator cannot evaluate get_dimension_size without "
1052         "set_dynamic_dimension_inference.");
1053   }
1054   HloInstruction* dynamic_size =
1055       dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
1056   if (dynamic_size != nullptr) {
1057     evaluated_[get_dimension_size] =
1058         GetEvaluatedLiteralFor(dynamic_size).Clone();
1059     return OkStatus();
1060   }
1061 
1062   const Shape& shape = get_dimension_size->operand(0)->shape();
1063   Literal output(ShapeUtil::MakeShape(S32, {}));
1064   output.PopulateWithValue(
1065       static_cast<int32_t>(shape.dimensions(get_dimension_size->dimension())));
1066   evaluated_[get_dimension_size] = std::move(output);
1067   return OkStatus();
1068 }
1069 
1070 Status HloEvaluator::HandleSetDimensionSize(
1071     HloInstruction* set_dimension_size) {
1072   const Literal& operand_literal =
1073       GetEvaluatedLiteralFor(set_dimension_size->operand(0));
1074   Literal result(set_dimension_size->shape());
1075   memcpy(result.untyped_data(), operand_literal.untyped_data(),
1076          operand_literal.size_bytes());
1077   const Literal& size_literal =
1078       GetEvaluatedLiteralFor(set_dimension_size->operand(1));
1079   result.SetDynamicSize(set_dimension_size->dimension(),
1080                         size_literal.Get<int32_t>({}));
1081   evaluated_[set_dimension_size] = std::move(result);
1082   return OkStatus();
1083 }
1084 
1085 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
1086   if (arg_literals_.empty()) {
1087     if (!enable_partial_evaluation_) {
1088       return tensorflow::errors::FailedPrecondition(
1089           "Failed to evaluate instruction since its operands are unknown "
1090           "or undetermined and partial evaluation is not enabled.");
1091     }
1092     evaluated_[parameter] =
1093         Literal::CreateFromShapeWithUnknownLeafArrays(parameter->shape());
1094     return OkStatus();
1095   }
1096 
1097   // Nothing to do other than sanity checks. Parameters' values are stored in
1098   // arg_literals_.
1099   CHECK_LT(parameter->parameter_number(), arg_literals_.size());
1100 
1101 #ifndef NDEBUG
1102   const Literal* input_literal = arg_literals_[parameter->parameter_number()];
1103   VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
1104   DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
1105                                                    input_literal->shape()))
1106       << "parameter shape is: "
1107       << ShapeUtil::HumanStringWithLayout(parameter->shape())
1108       << ", but input literal shape is: "
1109       << ShapeUtil::HumanStringWithLayout(input_literal->shape());
1110 #endif
1111 
1112   return OkStatus();
1113 }
1114 
1115 Status HloEvaluator::HandleInfeed(HloInstruction* infeed) {
1116   if (!enable_partial_evaluation_) {
1117     return tensorflow::errors::FailedPrecondition(
1118         "Failed to evaluate instruction since its operands are unknown "
1119         "or undetermined and partial evaluation is not enabled.");
1120   }
1121   evaluated_[infeed] =
1122       Literal::CreateFromShapeWithUnknownLeafArrays(infeed->shape());
1123   return OkStatus();
1124 }
1125 
1126 Status HloEvaluator::HandleConstant(HloInstruction*) { return OkStatus(); }
1127 
1128 Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
1129   TF_ASSIGN_OR_RETURN(evaluated_[reshape],
1130                       GetEvaluatedLiteralFor(reshape->operand(0))
1131                           .Reshape(reshape->shape().dimensions()));
1132   return OkStatus();
1133 }
1134 
1135 Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
1136   evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
1137                               .Transpose(transpose->dimensions());
1138   return OkStatus();
1139 }
1140 
1141 Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
1142   absl::Span<HloInstruction* const> operands(concatenate->operands());
1143   // The size of the result's concatenate dimension is the sum of the
1144   // concatenate dimensions of all operands taking part in the operation.
1145   const Shape& reference_shape = operands[0]->shape();
1146   CHECK(reference_shape.IsArray());
1147   const int64_t rank = reference_shape.rank();
1148   const int64_t concat_dim = concatenate->dimensions()[0];
1149   CHECK_GE(concat_dim, 0);
1150   CHECK_LT(concat_dim, rank);
1151 
1152   DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
1153                                     reference_shape.dimensions().end());
1154 
1155   for (int64_t i = 1; i < operands.size(); ++i) {
1156     const Shape& operand_shape = operands[i]->shape();
1157     CHECK(operand_shape.IsArray());
1158     // Accumulate the concat dimension from all tensors taking part in the
1159     // operation.
1160     concat_dimensions[concat_dim] +=
1161         ShapeUtil::GetDimension(operand_shape, concat_dim);
1162   }
1163 
1164   auto result_literal = LiteralUtil::CreateFromDimensions(
1165       reference_shape.element_type(), concat_dimensions);
1166   DimensionVector source_indices(rank, 0);
1167   DimensionVector dest_indices(concat_dimensions.size(), 0);
1168 
1169   for (auto operand : operands) {
1170     const Shape& operand_shape = operand->shape();
1171     TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
1172         GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
1173         operand_shape.dimensions()));
1174     dest_indices[concat_dim] +=
1175         ShapeUtil::GetDimension(operand_shape, concat_dim);
1176   }
1177 
1178   evaluated_[concatenate] = std::move(result_literal);
1179   return OkStatus();
1180 }
1181 
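// Example (illustrative): concatenating f32[2] {1, 2} with f32[3] {3, 4, 5}
// along dimension 0 yields f32[5] {1, 2, 3, 4, 5}; dest_indices[concat_dim]
// advances by 2 and then by 3 as each operand's slice is copied in.
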
1182 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
1183   auto operand = is_finite->operand(0);
1184   auto elem_ty = operand->shape().element_type();
1185   switch (elem_ty) {
1186     case PRED:
1187     case TUPLE:
1188     case OPAQUE_TYPE:
1189     case TOKEN:
1190     case S8:
1191     case S16:
1192     case S32:
1193     case S64:
1194     case U8:
1195     case U16:
1196     case U32:
1197     case U64:
1198     case C64:
1199     case C128:
1200     // Explicitly enumerate all types in this switch so that when we add a new
1201     // type, we'll get a compile error here.
1202     case PRIMITIVE_TYPE_INVALID:
1203     case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
1204     case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
1205       return InvalidArgument(
1206           "expected element type in shape to be floating point, but "
1207           "got: %s",
1208           PrimitiveType_Name(elem_ty));
1209 
1210     case F16: {
1211       auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
1212           is_finite,
1213           [](Eigen::half elem_operand) {
1214             return std::isfinite(static_cast<float>(elem_operand));
1215           },
1216           GetEvaluatedLiteralFor(operand));
1217       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1218       break;
1219     }
1220     case BF16: {
1221       auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
1222           is_finite,
1223           [](bfloat16 elem_operand) {
1224             return std::isfinite(static_cast<float>(elem_operand));
1225           },
1226           GetEvaluatedLiteralFor(operand));
1227       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1228       break;
1229     }
1230     case F32: {
1231       auto result_or = ElementWiseUnaryOpImpl<bool, float>(
1232           is_finite,
1233           [](float elem_operand) { return std::isfinite(elem_operand); },
1234           GetEvaluatedLiteralFor(operand));
1235       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1236       break;
1237     }
1238     case F64: {
1239       auto result_or = ElementWiseUnaryOpImpl<bool, double>(
1240           is_finite,
1241           [](double elem_operand) { return std::isfinite(elem_operand); },
1242           GetEvaluatedLiteralFor(operand));
1243       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1244       break;
1245     }
1246   }
1247 
1248   return OkStatus();
1249 }
1250 
1251 Status HloEvaluator::HandleReal(HloInstruction* real) {
1252   auto operand = real->operand(0);
1253   switch (operand->shape().element_type()) {
1254     case BF16: {
1255       auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
1256           real, [](bfloat16 elem_operand) { return elem_operand; },
1257           GetEvaluatedLiteralFor(operand));
1258       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1259       break;
1260     }
1261     case C64: {
1262       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
1263           real, [](complex64 elem_operand) { return std::real(elem_operand); },
1264           GetEvaluatedLiteralFor(operand));
1265       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1266       break;
1267     }
1268     case C128: {
1269       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
1270           real, [](complex128 elem_operand) { return std::real(elem_operand); },
1271           GetEvaluatedLiteralFor(operand));
1272       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1273       break;
1274     }
1275     case F16: {
1276       auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
1277           real, [](Eigen::half elem_operand) { return elem_operand; },
1278           GetEvaluatedLiteralFor(operand));
1279       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1280       break;
1281     }
1282     case F32: {
1283       auto result_or = ElementWiseUnaryOpImpl<float, float>(
1284           real, [](float elem_operand) { return elem_operand; },
1285           GetEvaluatedLiteralFor(operand));
1286       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1287       break;
1288     }
1289     case F64: {
1290       auto result_or = ElementWiseUnaryOpImpl<double, double>(
1291           real, [](double elem_operand) { return elem_operand; },
1292           GetEvaluatedLiteralFor(operand));
1293       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1294       break;
1295     }
1296     default:
1297       LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
1298                  << PrimitiveType_Name(operand->shape().element_type());
1299   }
1300 
1301   return OkStatus();
1302 }
1303 
1304 Status HloEvaluator::HandleImag(HloInstruction* imag) {
1305   auto operand = imag->operand(0);
1306   switch (operand->shape().element_type()) {
1307     case BF16: {
1308       auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
1309           imag, [](bfloat16 elem_operand) { return bfloat16(0); },
1310           GetEvaluatedLiteralFor(imag->operand(0)));
1311 
1312       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1313       break;
1314     }
1315     case C64: {
1316       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
1317           imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
1318           GetEvaluatedLiteralFor(imag->operand(0)));
1319 
1320       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1321       break;
1322     }
1323     case C128: {
1324       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
1325           imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
1326           GetEvaluatedLiteralFor(imag->operand(0)));
1327 
1328       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1329       break;
1330     }
1331     case F16: {
1332       auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
1333           imag, [](Eigen::half elem_operand) { return Eigen::half(0); },
1334           GetEvaluatedLiteralFor(imag->operand(0)));
1335 
1336       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1337       break;
1338     }
1339     case F32: {
1340       auto result_or = ElementWiseUnaryOpImpl<float, float>(
1341           imag, [](float elem_operand) { return 0; },
1342           GetEvaluatedLiteralFor(imag->operand(0)));
1343 
1344       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1345       break;
1346     }
1347     case F64: {
1348       auto result_or = ElementWiseUnaryOpImpl<double, double>(
1349           imag, [](double elem_operand) { return 0; },
1350           GetEvaluatedLiteralFor(imag->operand(0)));
1351 
1352       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1353       break;
1354     }
1355     default:
1356       LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
1357                  << PrimitiveType_Name(operand->shape().element_type());
1358   }
1359 
1360   return OkStatus();
1361 }
1362 
1363 Status HloEvaluator::HandleComplex(HloInstruction* complex) {
1364   const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
1365   const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
1366   TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
1367 
1368   Literal result(complex->shape());
1369   switch (complex->shape().element_type()) {
1370     case C64: {
1371       TF_RETURN_IF_ERROR(result.Populate<complex64>(
1372           [&](absl::Span<const int64_t> multi_index) {
1373             return std::complex<float>(real.Get<float>(multi_index),
1374                                        imag.Get<float>(multi_index));
1375           }));
1376       break;
1377     }
1378     case C128: {
1379       TF_RETURN_IF_ERROR(result.Populate<complex128>(
1380           [&](absl::Span<const int64_t> multi_index) {
1381             return std::complex<double>(real.Get<double>(multi_index),
1382                                         imag.Get<double>(multi_index));
1383           }));
1384       break;
1385     }
1386     default:
1387       LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
1388                  << PrimitiveType_Name(complex->shape().element_type());
1389   }
1390 
1391   evaluated_[complex] = std::move(result);
1392   return OkStatus();
1393 }
1394 
1395 Status HloEvaluator::HandleCompare(HloInstruction* compare) {
1396   ComparisonDirection direction = compare->comparison_direction();
1397   auto lhs = compare->operand(0);
1398   auto rhs = compare->operand(1);
1399   DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
1400          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
1401 
1402   TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
1403 
1404   const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
1405   const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
1406 
1407   // Note here we switch on the operand's type.
1408   switch (lhs->shape().element_type()) {
1409     case PRED: {
1410       TF_ASSIGN_OR_RETURN(
1411           evaluated_[compare],
1412           Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
1413     } break;
1414     case U8: {
1415       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1416                           Compare<uint8_t>(compare->shape(), direction,
1417                                            lhs_literal, rhs_literal));
1418     } break;
1419     case U16: {
1420       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1421                           Compare<uint16_t>(compare->shape(), direction,
1422                                             lhs_literal, rhs_literal));
1423     } break;
1424     case U32: {
1425       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1426                           Compare<uint32_t>(compare->shape(), direction,
1427                                             lhs_literal, rhs_literal));
1428     } break;
1429     case U64: {
1430       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1431                           Compare<uint64_t>(compare->shape(), direction,
1432                                             lhs_literal, rhs_literal));
1433     } break;
1434     case S8: {
1435       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1436                           Compare<int8_t>(compare->shape(), direction,
1437                                           lhs_literal, rhs_literal));
1438     } break;
1439     case S16: {
1440       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1441                           Compare<int16_t>(compare->shape(), direction,
1442                                            lhs_literal, rhs_literal));
1443     } break;
1444     case S32: {
1445       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1446                           Compare<int32_t>(compare->shape(), direction,
1447                                            lhs_literal, rhs_literal));
1448     } break;
1449     case S64: {
1450       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1451                           Compare<int64_t>(compare->shape(), direction,
1452                                            lhs_literal, rhs_literal));
1453     } break;
1454     case F16: {
1455       TF_ASSIGN_OR_RETURN(
1456           evaluated_[compare],
1457           Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
1458     } break;
1459     case BF16: {
1460       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1461                           Compare<bfloat16>(compare->shape(), direction,
1462                                             lhs_literal, rhs_literal));
1463     } break;
1464     case F32: {
1465       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1466                           Compare<float>(compare->shape(), direction,
1467                                          lhs_literal, rhs_literal));
1468     } break;
1469     case F64: {
1470       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1471                           Compare<double>(compare->shape(), direction,
1472                                           lhs_literal, rhs_literal));
1473     } break;
1474     case C64: {
1475       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1476                           Compare<complex64>(compare->shape(), direction,
1477                                              lhs_literal, rhs_literal));
1478     } break;
1479     case C128: {
1480       TF_ASSIGN_OR_RETURN(evaluated_[compare],
1481                           Compare<complex128>(compare->shape(), direction,
1482                                               lhs_literal, rhs_literal));
1483     } break;
1484     default:
1485       LOG(FATAL) << "HandleCompare: unknown primitive type: "
1486                  << PrimitiveType_Name(lhs->shape().element_type());
1487   }
1488 
1489   return OkStatus();
1490 }
1491 
1492 Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
1493   std::vector<const Literal*> operand_literals;
1494   std::vector<Literal> operand_literal_values;
1495   if (!visitor_shape_index_.empty()) {
1496     // We only need to evaluate tuple at visitor_shape_index_. The other
1497     // operands might not have been evaluated, so mark the other operands as
1498     // undetermined.
1499     int64_t tuple_index = visitor_shape_index_.front();
1500     operand_literal_values.resize(tuple->operand_count());
1501     for (int operand_index = 0; operand_index < tuple->operand_count();
1502          ++operand_index) {
1503       if (operand_index == tuple_index) {
1504         operand_literals.push_back(
1505             &GetEvaluatedLiteralFor(tuple->mutable_operand(operand_index)));
1506       } else {
1507         operand_literal_values[operand_index] =
1508             Literal::CreateFromShapeWithUndeterminedLeafArrays(
1509                 ShapeUtil::GetSubshape(tuple->shape(), {operand_index}));
1510         operand_literals.push_back(&operand_literal_values[operand_index]);
1511       }
1512     }
1513   } else {
1514     for (auto operand : tuple->operands()) {
1515       operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
1516     }
1517   }
1518 
1519   if (evaluated_.contains(tuple)) {
1520     Literal new_result = LiteralUtil::MakeTuple(operand_literals);
1521     CHECK(new_result.IsDetermined(visitor_shape_index_));
1522     TF_RETURN_IF_ERROR(
1523         evaluated_[tuple].CopyFrom(new_result,
1524                                    /*dest_shape_index=*/visitor_shape_index_,
1525                                    /*src_shape_index=*/visitor_shape_index_));
1526   } else {
1527     evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
1528   }
1529   return OkStatus();
1530 }
1531 
1532 namespace {
1533 
1534 // These helper templates convert between data types and are intended to be used only
1535 // within the DFT implementation below. The special case is IRFFT, where the
1536 // specialization drops imaginary parts of complex values and returns real
1537 // numbers.
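// For instance, TypeConverter<float, complex64>::GetAs(complex64(2.0f, 3.0f))
// returns 2.0f via the float specialization below, while the primary template
// performs a plain static_cast between numeric types.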
1538 template <typename ToType, typename FromType>
1539 struct TypeConverter {
1540   static inline ToType GetAs(FromType value) {
1541     return static_cast<ToType>(value);
1542   }
1543 };
1544 
1545 template <typename FromType>
1546 struct TypeConverter<float, FromType> {
1547   static inline float GetAs(FromType value) {
1548     return static_cast<float>(value.real());
1549   }
1550 };
1551 
1552 // This class implements the discrete Fourier transform. All transform types
1553 // (FFT, IFFT, RFFT, and IRFFT) are supported, with arbitrary transform rank,
1554 // arbitrary length of each transform dimension, and arbitrary layouts of the input
1555 // and output literals. The class template parameter must be a complex type, and
1556 // all internal calculations will be performed using this type.
1557 //
1558 // The input literal provides input data, which must be complex64 for FFT, IFFT,
1559 // IRFFT transforms and float for RFFT. The transform is computed over the
1560 // innermost dimensions of the input, thus the rank of the input data must be
1561 // the same as fft_rank or larger. The input is expected to provide Ni values along
1562 // each transform axis with one exception: for IRFFT, only (N0 / 2) + 1 values
1563 // are needed along the X axis (the innermost index). To increase flexibility,
1564 // this implementation can handle mismatches between the input size and
1565 // transform lengths by either dropping extra input values or using zeroes in
1566 // place of missing input values as necessary. If the input data has rank higher
1567 // than the transform, the transform is applied for each valid combination of
1568 // the higher-ranking indices.
1569 //
1570 // The output contains complex64 values for FFT, IFFT, RFFT, and float values
1571 // for IRFFT. The rank of the output as well as the sizes of the dimensions
1572 // above the rank of the transform must match those of the input. Sizes of the
1573 // output's "fft_rank" innermost dimensions are expected to match the length of
1574 // the transform along respective axes with one exception: for RFFT, the output
1575 // is trimmed along the X axis to have only (N0 / 2) + 1 values. If the
1576 // lengths do not match, the FFT output is trimmed to fit into the provided output
1577 // shape, or the output is padded with zero values appropriately.
1578 //
1579 // For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17]
1580 // input array will perform two transforms over the [][15][17] data in the sub
1581 // arrays [0][][] and [1][][], dropping the values along axis X and padding axis
1582 // Y with zeroes to create 16x16 working sets, and generating
1583 // complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to
1584 // complex64[64][16][9] input array will use all input values and will produce
1585 // float[64][16][16] output.
1586 //
1587 // The implementation of the 1D transform for lengths that are powers of 2 is
1588 // the Cooley-Tukey radix-2 decimation-in-time. For all other 1D transform
1589 // lengths, a straightforward, but slow, loop nest is used. The transforms of
1590 // higher ranks apply sets of 1D transforms along each axis. For example, the 2D
1591 // transform is computed by applying 1D transforms to each column followed by
1592 // applying 1D transforms to each row.
1593 //
1594 // In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
1595 // time, where Ni is the length of the transform's i-th dimension. However, for
1596 // dimension lengths that are powers of 2, the run time along these dimensions
1597 // is reduced to log(Ni) in the summation, giving the runtime of
1598 // O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn))) in the best case.
1599 //
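// As a concrete reading of that bound (numbers purely illustrative): a 2D
// 16x16 transform costs on the order of 16*16*(16+16) = 8192 operations with
// the naive kernel, versus roughly 16*16*(log2(16)+log2(16)) = 2048 when both
// lengths are powers of 2 and the radix-2 path is taken.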
1600 template <typename ComplexType>
1601 class FftTransform {
1602  public:
1603   explicit FftTransform(HloInstruction* fft)
1604       : fft_type_(fft->fft_type()),
1605         fft_rank_(fft->fft_length().size()),
1606         fft_lengths_(fft->fft_length()) {
1607     // Make fft_lengths_[0] the minormost dimension.
1608     absl::c_reverse(fft_lengths_);
1609   }
1610 
1611   Status ComputeFft(HloInstruction* fft, const Literal& input_literal,
1612                     Literal* output_literal) {
1613     const Shape& input_shape = input_literal.shape();
1614     const Shape& output_shape = fft->shape();
1615 
1616     TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape));
1617 
1618     const auto fft_strides = ComputeStrides(fft_lengths_);
1619 
1620     // Working set size.
1621     const int64_t fft_size = fft_strides[fft_rank_];
1622 
1623     if (fft_size > 0) {
1624       // Linearized working data set.
1625       std::vector<ComplexType> data(fft_size);
1626 
1627       // Temporary buffer allocated once and used in 1D sweeps. For dimension
1628       // length values that are powers of 2, the buffer should be twice as large
1629       // to simultaneously hold input and output in Fft1D() below.
1630       int64_t buffer_size = 0;
1631       for (auto len : fft_lengths_) {
1632         int64_t size =
1633             absl::has_single_bit(static_cast<uint64_t>(len)) ? len * 2 : len;
1634         buffer_size = std::max(buffer_size, size);
1635       }
1636       std::vector<ComplexType> buffer(buffer_size);
1637 
1638       // Sizes of each axis of input and output literals.
1639       const auto input_lengths = GetDimensionLengths(input_literal);
1640       const auto output_lengths = GetDimensionLengths(*output_literal);
1641 
1642       // Strides for generating linearized indices into multidimensional arrays.
1643       const auto input_strides = ComputeStrides(input_lengths, input_literal);
1644       const auto output_strides =
1645           ComputeStrides(output_lengths, *output_literal);
1646 
1647       // Visit all elements in the dimensions with ranks above the FFT rank. For
1648       // each such element invoke the transform. Use separate indices for the
1649       // input and the output to allow different layouts.
1650       auto base_case = [&](int64_t axis, int64_t output_index,
1651                            int64_t input_index, bool within_src_bounds) {
1652         if (axis == fft_rank_ - 1) {
1653           // Base case: copy the data from the input literal, apply the
1654           // transform, and copy the result to the output literal.
1655           CHECK(within_src_bounds);
1656           bool input_is_zero = CopyDataFromInput(
1657               input_literal, input_index, fft_size, fft_lengths_, fft_strides,
1658               input_lengths, input_strides, absl::MakeSpan(data));
1659           if (!input_is_zero) {
1660             // Make 1D sweeps along each transform axis.
1661             Sweep(fft_lengths_, fft_strides, absl::MakeSpan(data),
1662                   absl::MakeSpan(buffer));
1663           }
1664           CopyDataToOutput(absl::MakeSpan(data), output_index, fft_lengths_,
1665                            fft_strides, output_lengths, output_strides,
1666                            output_literal);
1667           return true;
1668         }
1669         return false;
1670       };
1671       GenerateIndices(output_lengths, output_strides, input_lengths,
1672                       input_strides, input_shape.rank(), 0, 0, base_case);
1673     }
1674 
1675     return OkStatus();
1676   }
1677 
1678  private:
1679   // Common code used by 1D implementations, which copies data from the input to
1680   // the contiguous buffer. Returns true if all copied values are zero.
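  // For example, with length = 4 and expand_input = true, only ub = 3 values
  // {a, b, c} are gathered from the data, and the buffer is completed to
  // {a, b, c, conj(b)} using the Hermitian symmetry of a real signal's
  // spectrum.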
1681   static bool GatherToBuffer(absl::Span<ComplexType> data, int64_t length,
1682                              int64_t start, int64_t stride, bool expand_input,
1683                              absl::Span<ComplexType> buffer) {
1684     CHECK_GE(buffer.size(), length);
1685     bool input_is_zero = true;
1686     const int64_t ub = expand_input ? length / 2 + 1 : length;
1687     CHECK_GE(data.size(), start + (ub - 1) * stride);
1688     for (int64_t k = 0; k < ub; k++) {
1689       ComplexType value = data[start + k * stride];
1690       input_is_zero &= value == ComplexType(0.0, 0.0);
1691       buffer[k] = value;
1692       if (expand_input) {
1693         // Use conjugates of the values at indices [1 ... (ub - 2)] when the
1694         // length is even and at indices [1 ... (ub - 1)] when the length is odd
1695         // to calculate missing values at indices [ub ... (length - 1)].
1696         if (k > 0 && k < (length - ub + 1)) {
1697           buffer[length - k] = std::conj(value);
1698         }
1699       }
1700     }
1701     return input_is_zero;
1702   }
1703 
1704   // Returns (conjugated, if 'inverse' is true) k-th twiddle for the given
1705   // length.
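  // In math notation, this is the DFT root of unity W_N^k = e^(-2*pi*i*k/N);
  // e.g. Twiddle(1, 4, /*inverse=*/false) == -i.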
1706   static inline ComplexType Twiddle(int64_t k, int64_t length, bool inverse) {
1707     auto coeff = std::exp(ComplexType(0.0, -2.0 * M_PI * k / length));
1708     return inverse ? std::conj(coeff) : coeff;
1709   }
1710 
1711   // Straightforward implementation of 1D DFT transform of arbitrary length.
1712   // Uses passed-in start index and stride to gather inputs from the data vector
1713   // into the preallocated buffer, computes the result, and writes it back to
1714   // the same locations in the data vector. Runs in O(length^2) time.
1715   //
1716   // Parameters contract_output and expand_input are used to avoid unnecessary
1717   // calculations. When contract_output is set to true, then only (length / 2) +
1718   // 1 output values are computed. When expand_input is set to true, then
1719   // (length / 2) + 1 values from the data set are used to re-create the full
1720   // set of size 'length', on which the transform is then performed.
1721   //
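  // Concretely, the value written for output index k below is the textbook
  // DFT sum X[k] = sum_{n=0}^{length-1} buffer[n] * W^(n*k), with W as defined
  // in Twiddle() and an extra 1/length scale applied for the inverse
  // transform.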
1722   static void NaiveDft1D(int64_t length, int64_t start, int64_t stride,
1723                          bool inverse, bool contract_output, bool expand_input,
1724                          absl::Span<ComplexType> data,
1725                          absl::Span<ComplexType> buffer) {
1726     const bool input_is_zero =
1727         GatherToBuffer(data, length, start, stride, expand_input, buffer);
1728 
1729     if (!input_is_zero) {
1730       const int64_t ub = contract_output ? length / 2 + 1 : length;
1731       for (int64_t k = 0; k < ub; k++) {
1732         ComplexType value = ComplexType(0.0, 0.0);
1733         for (int n = 0; n < length; n++) {
1734           value += buffer[n] * Twiddle(n * k, length, inverse);
1735         }
1736         data[start + k * stride] =
1737             inverse ? value / ComplexType(length, 0.0) : value;
1738       }
1739     }
1740   }
1741 
1742   // Non-recursive implementation of the Cooley-Tukey radix-2 decimation in
1743   // time. Performs 1D FFT transform for the lengths, which are powers of 2.
1744   // Runs in O(length * log(length)) time. Uses the same parameters as the naive
1745   // implementation above, except that the preallocated buffer must be at least
1746   // twice as big as the length of the transform, because the buffer is used to
1747   // hold both input and output values for each stage of the transform.
1748   //
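  // Each stage below applies the standard radix-2 decimation-in-time
  // butterfly: for partial results E (even) and O (odd) and twiddle factor W,
  // it writes E + W*O and E - W*O into the output half of the buffer.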
1749   static void Fft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
1750                     bool contract_output, bool expand_input,
1751                     absl::Span<ComplexType> data,
1752                     absl::Span<ComplexType> buffer) {
1753     CHECK(absl::has_single_bit(static_cast<uint64_t>(length)));
1754     const bool input_is_zero =
1755         GatherToBuffer(data, length, start, stride, expand_input, buffer);
1756 
1757     if (!input_is_zero) {
1758       auto generate_twiddles = [](int64_t length, bool inverse) {
1759         std::vector<ComplexType> twiddles;
1760         // Need only half the twiddles.
1761         for (int64_t k = 0; k < length / 2; k++) {
1762           twiddles.push_back(Twiddle(k, length, inverse));
1763         }
1764         return twiddles;
1765       };
1766 
1767       // Indices into the parts of the buffer used for input and output values.
1768       int64_t in_base = length;
1769       int64_t out_base = 0;
1770 
1771       // At each stage, we "split" the input data into num_blocks, with
1772       // block_size values in each block.
1773       for (int64_t num_blocks = 1; num_blocks < length; num_blocks *= 2) {
1774         // Swap input and output parts of the buffer.
1775         std::swap(in_base, out_base);
1776         auto twiddles = generate_twiddles(num_blocks * 2, inverse);
1777         const int64_t block_size = length / num_blocks;
1778         const int64_t next_iteration_block_size = block_size / 2;
1779         for (int64_t block = 0; block < num_blocks; block++) {
1780           const int64_t in_offset = in_base + block * block_size;
1781           const int64_t out_offset =
1782               out_base + block * next_iteration_block_size;
1783           // For each (even, odd) pair of values in the block, calculate two
1784           // output values as even + twiddle * odd and even - twiddle * odd.
1785           for (int64_t pair = 0; pair < block_size / 2; pair++) {
1786             const ComplexType even = buffer[in_offset + pair];
1787             const ComplexType odd = buffer[in_offset + block_size / 2 + pair];
1788             const ComplexType twiddled_odd = twiddles[block] * odd;
1789             buffer[out_offset + pair] = even + twiddled_odd;
1790             buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
1791           }
1792         }
1793       }
1794       // Copy computed result back to data.
1795       const int64_t ub = contract_output ? length / 2 + 1 : length;
1796       for (int64_t k = 0; k < ub; k++) {
1797         ComplexType value = buffer[out_base + k];
1798         data[start + k * stride] =
1799             inverse ? value / ComplexType(length, 0.0) : value;
1800       }
1801     }
1802   }
1803 
1804   // Determine which implementation of the 1D transform to use, and call it.
1805   static void Dft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
1806                     bool contract_output, bool expand_input,
1807                     absl::Span<ComplexType> data,
1808                     absl::Span<ComplexType> buffer) {
1809     if (absl::has_single_bit(static_cast<uint64_t>(length))) {
1810       Fft1D(length, start, stride, inverse, contract_output, expand_input, data,
1811             buffer);
1812     } else {
1813       NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
1814                  data, buffer);
1815     }
1816   }
1817 
1818   // Helper to reverse the order of dimension lengths in the passed-in literal.
1819   static std::vector<int64_t> GetDimensionLengths(const Literal& literal) {
1820     auto dimensions = literal.shape().dimensions();
1821     return std::vector<int64_t>(dimensions.rbegin(), dimensions.rend());
1822   }
1823 
1824   // Helper to compute strides for creating linear indices into multidimensional
1825   // data from the dimension lengths and the layout. Returns a new vector of
1826   // size lengths.size() + 1. The last element of the returned vector at index
1827   // [lengths.size()] contains the product of all dimension lengths.
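  // For example (hypothetical values): lengths {4, 3} (minor-most axis first)
  // with the default layout yield strides {1, 4, 12}, where strides[0] = 1
  // steps along the minor-most axis and strides[2] = 12 is the total element
  // count.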
1828   static std::vector<int64_t> ComputeStrides(
1829       const absl::Span<const int64_t> lengths, const Layout& layout) {
1830     const int64_t num_dimensions = lengths.size();
1831 
1832     // Make sure that the layout length matches the number of dimensions.
1833     CHECK_EQ(num_dimensions, layout.minor_to_major_size());
1834 
1835     // Calculate strides using layout-specified ordering of the dimensions and
1836     // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
1837     std::vector<int64_t> strides(num_dimensions + 1);
1838     int64_t stride = 1;
1839     for (int64_t i = 0; i < num_dimensions; i++) {
1840       // Reverse the ordering of the dimensions in the layout.
1841       const int64_t index = (num_dimensions - 1) - layout.minor_to_major(i);
1842       strides[index] = stride;
1843       stride *= lengths[index];
1844     }
1845     strides[num_dimensions] = stride;
1846 
1847     return strides;
1848   }
1849 
1850   // Compute strides as above using the default layout.
1851   static std::vector<int64_t> ComputeStrides(
1852       const absl::Span<const int64_t> lengths) {
1853     return ComputeStrides(lengths,
1854                           LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
1855   }
1856 
1857   // Compute strides as above using the layout from the literal, if available.
1858   static std::vector<int64_t> ComputeStrides(
1859       const absl::Span<const int64_t> lengths, const Literal& literal) {
1860     return literal.shape().has_layout()
1861                ? ComputeStrides(lengths, literal.shape().layout())
1862                : ComputeStrides(lengths);
1863   }
1864 
1865   // Make 1D sweeps along each transform axis.
1866   void Sweep(const absl::Span<const int64_t> fft_lengths,
1867              const absl::Span<const int64_t> fft_strides,
1868              absl::Span<ComplexType> data, absl::Span<ComplexType> buffer) {
1869     const bool inverse =
1870         fft_type_ == FftType::IFFT || fft_type_ == FftType::IRFFT;
1871     const bool input_is_truncated = fft_type_ == FftType::IRFFT;
1872     const bool output_is_truncated = fft_type_ == FftType::RFFT;
1873 
1874     // Recursively visit each column of the data along the sweep_axis. Calculate
1875     // linearized index of that column's first element and the stride, then
1876     // invoke 1D transform. For RFFT, avoid calculating unused output values:
1877     // first, compute only (length_x / 2) + 1 values along the X axis, then
1878     // limit the X coordinate to [0 ... (length / 2)] during the sweeps along
1879     // other axes. Similarly, for IRFFT sweep along higher dimensions first,
1880     // while keeping the X coordinate in the [0 ... (length / 2)] range, then
1881     // re-create negative frequencies omitted in the input and perform the
1882     // full-length transform along the X axis in the last sweep.
1883     std::function<void(int64_t, int64_t, int64_t)> sweep =
1884         [&](int64_t sweep_axis, int64_t axis, int64_t start) {
1885           if (axis < 0) {
1886             // Base case: invoke 1D transform.
1887             const int64_t length = fft_lengths[sweep_axis];
1888             const int64_t stride = fft_strides[sweep_axis];
1889             const bool expand_input = input_is_truncated && sweep_axis == 0;
1890             const bool contract_output =
1891                 output_is_truncated && sweep_axis == 0;
1892             Dft1D(length, start, stride, inverse, contract_output,
1893                   expand_input, data, buffer);
1894           } else if (axis == sweep_axis) {
1895             // Visit only the elements with coordinate 0 along the sweep axis.
1896             sweep(sweep_axis, axis - 1, start);
1897           } else {
1898             const int64_t length = fft_lengths[axis];
1899             const bool is_truncated = input_is_truncated || output_is_truncated;
1900             const int64_t ub =
1901                 is_truncated && axis == 0 ? (length / 2) + 1 : length;
1902             for (int64_t i = 0; i < ub; i++) {
1903               sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
1904             }
1905           }
1906         };
1907     if (input_is_truncated) {
1908       // Sweep along the X axis last for IRFFT.
1909       for (int64_t sweep_axis = fft_rank_ - 1; sweep_axis >= 0; sweep_axis--) {
1910         sweep(sweep_axis, fft_rank_ - 1, 0);
1911       }
1912     } else {
1913       // Sweep along the X axis first for RFFT. The order does not matter for
1914       // FFT and IFFT types; handle them here as well.
1915       for (int64_t sweep_axis = 0; sweep_axis < fft_rank_; sweep_axis++) {
1916         sweep(sweep_axis, fft_rank_ - 1, 0);
1917       }
1918     }
1919   }
1920 
1921   // This template generates two linearized indices, which can be used to access
1922   // multidimensional arrays. It uses a recursive function, which passes the
1923   // indices to the user-supplied callback function. The destination index is
1924   // always within dst_lengths[] bounds. The boolean parameter within_src_bounds
1925   // indicates whether the source index is within src_lengths[] bounds.
1926   //
1927   // The value returned from the callback function controls the recursion depth.
1928   // Returning true indicates that the base case has been hit and the recursion
1929   // stops. Otherwise, the recursion proceeds along the next less-major axis.
1930   //
1931   // For example, the base case when the axis value becomes negative invokes the
1932   // callback function for each possible index within dst_lengths[] bounds. The
1933   // base case when the axis value is equal to zero limits the indices to point
1934   // only to first elements along the minor-most dimension, allowing the
1935   // callback function to handle all values along the X axis.
1936   //
1937   template <typename BaseFn>
1938   static void GenerateIndices(const absl::Span<const int64_t> dst_lengths,
1939                               const absl::Span<const int64_t> dst_strides,
1940                               const absl::Span<const int64_t> src_lengths,
1941                               const absl::Span<const int64_t> src_strides,
1942                               int64_t rank, int64_t dst_start,
1943                               int64_t src_start, BaseFn&& base) {
1944     CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
1945     CHECK_GE(dst_lengths.size(), rank);
1946     CHECK_EQ(src_lengths.size() + 1, src_strides.size());
1947     CHECK_GE(src_lengths.size(), rank);
1948 
1949     std::function<void(int64_t, int64_t, int64_t, bool)> generate =
1950         [&](int64_t axis, int64_t dst_index, int64_t src_index,
1951             bool within_src_bounds) {
1952           if (!base(axis, dst_index, src_index, within_src_bounds)) {
1953             for (int64_t i = 0; i < dst_lengths[axis]; i++) {
1954               // Because the loop goes over dst_lengths[], the source index may
1955               // be out of src_lengths[] bounds. In this case, within_src_bounds
1956               // is false.
1957               within_src_bounds &= i < src_lengths[axis];
1958               generate(axis - 1, dst_index, src_index, within_src_bounds);
1959               dst_index += dst_strides[axis];
1960               src_index += src_strides[axis];
1961             }
1962           }
1963         };
1964     generate(rank - 1, dst_start, src_start, true);
1965   }
1966 
1967   // Copies the input data from a literal to a pre-allocated vector. The sizes
1968   // of the input and the transform do not need to match. For each axis of the
1969   // transform, any extra input values beyond the transform length are ignored.
1970   // Conversely, if the input does not contain enough elements along any axis,
1971   // the data is padded with zeroes.
1972   //
1973   // For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
1974   // where length_x is the size of the full transform along the X axis.
1975   //
1976   // The input literal may have a rank higher than the rank of the transform.
1977   // Passed-in input_index value points to the first element of the input
1978   // literal to be copied.
1979   //
1980   // Returns true if all values in the work data set are zeroes.
1981   //
1982   template <typename InputType>
1983   bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
1984                          int64_t fft_size,
1985                          const absl::Span<const int64_t> fft_lengths,
1986                          const absl::Span<const int64_t> fft_strides,
1987                          const absl::Span<const int64_t> input_lengths,
1988                          const absl::Span<const int64_t> input_strides,
1989                          absl::Span<ComplexType> data) {
1990     CHECK_GE(data.size(), fft_size);
1991 
1992     const bool input_is_truncated = fft_type_ == FftType::IRFFT;
1993 
1994     // Recursively visit each transform dimension to copy input values to the
1995     // working data set. The base case handles inputs along the X axis.
1996     bool input_is_zero = true;
1997     const InputType* input_data = input_literal.data<InputType>().data();
1998     auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
1999                          bool within_src_bounds) {
2000       if (axis == 0) {
2001         // For IRFFT, the negative frequencies are only needed for the sweep
2002         // along the X axis, which is performed last. Leave this part of the
2003         // working set uninitialized until then.
2004         const int64_t length = fft_lengths[axis];
2005         const int64_t ub = input_is_truncated ? (length / 2) + 1 : length;
2006         for (int64_t i = 0; i < ub; i++) {
2007           ComplexType value = ComplexType(0);
2008           // Read input value only if the index is within bounds.
2009           if (within_src_bounds && i < input_lengths[axis]) {
2010             value = TypeConverter<ComplexType, InputType>::GetAs(
2011                 input_data[src_index + i * input_strides[axis]]);
2012             input_is_zero &= value == ComplexType(0.0, 0.0);
2013           }
2014           data[dst_index + i * fft_strides[axis]] = value;
2015         }
2016         return true;
2017       }
2018       return false;
2019     };
2020     GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
2021                     fft_rank_, 0, input_start, base_case);
2022     return input_is_zero;
2023   }
2024 
2025   // Copies the result of the transform to the literal output. The sizes of the
2026   // transform and output must match.
2027   //
2028   // For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
2029   // the size of the full transform along the X axis (the most minor dimension).
2030   //
2031   // The output literal may have a rank higher than the rank of the transform.
2032   // Passed-in output_index value points to the first element of the output
2033   // literal to be filled in.
2034   //
2035   template <typename OutputType>
2036   void CopyDataToOutput(const absl::Span<ComplexType> data,
2037                         int64_t output_start,
2038                         const absl::Span<const int64_t> fft_lengths,
2039                         const absl::Span<const int64_t> fft_strides,
2040                         const absl::Span<const int64_t> output_lengths,
2041                         const absl::Span<const int64_t> output_strides,
2042                         Literal* output_literal) {
2043     const bool output_is_truncated = fft_type_ == FftType::RFFT;
2044 
2045     // Base case for recursive copy of the results to the output. The code
2046     // avoids making a recursive call for each output element by handling axis 0
2047     // in the loop (as opposed to making "axis < 0" the base case).
2048     OutputType* output_data = output_literal->data<OutputType>().data();
2049     auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
2050                          bool within_src_bounds) {
2051       if (axis == 0) {
2052         // Drop negative frequencies for RFFT.
2053         const int64_t length = fft_lengths[axis];
2054         const int64_t ub = output_is_truncated ? (length / 2) + 1 : length;
2055         for (int64_t i = 0; i < output_lengths[axis]; i++) {
2056           OutputType value = OutputType(0);
2057           // Read data only if the index is within bounds.
2058           if (within_src_bounds && i < ub) {
2059             value = TypeConverter<OutputType, ComplexType>::GetAs(
2060                 data[src_index + i * fft_strides[axis]]);
2061           }
2062           output_data[dst_index + i * output_strides[axis]] = value;
2063         }
2064         return true;
2065       }
2066       return false;
2067     };
2068     GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
2069                     fft_rank_, output_start, 0, base_case);
2070   }
2071 
2072   // Determine the type to use with the CopyDataFromInput<> template above.
2073   bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
2074                          int64_t fft_size,
2075                          const absl::Span<const int64_t> fft_lengths,
2076                          const absl::Span<const int64_t> fft_strides,
2077                          const absl::Span<const int64_t> input_lengths,
2078                          const absl::Span<const int64_t> input_strides,
2079                          absl::Span<ComplexType> data) {
2080     const bool input_is_float = fft_type_ == FftType::RFFT;
2081     if (input_is_float) {
2082       return CopyDataFromInput<float>(input_literal, input_start, fft_size,
2083                                       fft_lengths, fft_strides, input_lengths,
2084                                       input_strides, data);
2085     } else {
2086       return CopyDataFromInput<complex64>(input_literal, input_start, fft_size,
2087                                           fft_lengths, fft_strides,
2088                                           input_lengths, input_strides, data);
2089     }
2090   }
2091 
2092   // Determine the type to use with the CopyDataToOutput<> template above.
2093   void CopyDataToOutput(const absl::Span<ComplexType> data,
2094                         int64_t output_start,
2095                         const absl::Span<const int64_t> fft_lengths,
2096                         const absl::Span<const int64_t> fft_strides,
2097                         const absl::Span<const int64_t> output_lengths,
2098                         const absl::Span<const int64_t> output_strides,
2099                         Literal* output_literal) {
2100     const bool output_is_float = fft_type_ == FftType::IRFFT;
2101     if (output_is_float) {
2102       CopyDataToOutput<float>(data, output_start, fft_lengths, fft_strides,
2103                               output_lengths, output_strides, output_literal);
2104     } else {
2105       CopyDataToOutput<complex64>(data, output_start, fft_lengths, fft_strides,
2106                                   output_lengths, output_strides,
2107                                   output_literal);
2108     }
2109   }
2110 
2111   Status CheckParameters(const Shape& input_shape, const Shape& output_shape) {
2112     // Check FFT parameters.
2113     if (fft_rank_ <= 0) {
2114       return InvalidArgument("Zero or negative FFT rank.");
2115     }
2116     if (*absl::c_min_element(fft_lengths_) < 0) {
2117       return InvalidArgument("Negative FFT length.");
2118     }
2119 
2120     // Check input-related values.
2121     TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
2122     if (!input_shape.IsArray()) {
2123       return Unimplemented("Only array input shapes are supported.");
2124     }
2125     auto input_elt_type = input_shape.element_type();
2126     if (fft_type_ == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
2127       return InvalidArgument("Invalid input type: %d, must be %d (float).",
2128                              input_elt_type, PrimitiveType::F32);
2129     }
2130     if (fft_type_ != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
2131       return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
2132                              input_elt_type, PrimitiveType::C64);
2133     }
2134     const int64_t input_rank = input_shape.rank();
2135     if (input_rank < fft_rank_) {
2136       return InvalidArgument("Input shape rank is smaller than FFT rank.");
2137     }
2138 
2139     // Check output-related values.
2140     TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
2141     if (!output_shape.IsArray()) {
2142       return Unimplemented("Only array output shapes are supported.");
2143     }
2144     auto output_elt_type = output_shape.element_type();
2145     if (fft_type_ == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
2146       return InvalidArgument("Invalid output type: %d, must be %d (float).",
2147                              output_elt_type, PrimitiveType::F32);
2148     }
2149     if (fft_type_ != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
2150       return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
2151                              output_elt_type, PrimitiveType::C64);
2152     }
2153     const int64_t output_rank = output_shape.rank();
2154     if (output_rank < fft_rank_) {
2155       return InvalidArgument("Output shape rank is smaller than FFT rank.");
2156     }
2157 
2158     // Consistency of input and output parameters.
2159     if (input_rank != output_rank) {
2160       return InvalidArgument(
2161           "Ranks of input shape and output shape do not match.");
2162     }
2163     for (int64_t dim = 0; dim < input_rank - fft_rank_; dim++) {
2164       if (ShapeUtil::GetDimension(input_shape, dim) !=
2165           ShapeUtil::GetDimension(output_shape, dim)) {
2166         return InvalidArgument(
2167             "Higher dimension lengths of input shape and output shape do not "
2168             "match.");
2169       }
2170     }
2171 
2172     return OkStatus();
2173   }
2174 
2175  private:
2176   const FftType fft_type_;
2177   const int64_t fft_rank_;
2178   std::vector<int64_t> fft_lengths_;
2179 };
2180 
2181 }  // namespace
2182 
2183 Status HloEvaluator::HandleFft(HloInstruction* fft) {
2184   const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
2185   Literal output_literal = Literal::CreateFromShape(fft->shape());
2186 
2187   FftTransform<complex128> transform(fft);
2188   TF_RETURN_IF_ERROR(transform.ComputeFft(fft, input_literal, &output_literal));
2189   evaluated_[fft] = std::move(output_literal);
2190 
2191   return OkStatus();
2192 }
2193 
2194 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
2195 // dimensions while keeping the rest of the output dimensions clamped to 0.
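// For instance (illustrative numbers): for an output shape with dimensions
// {5, 7, 9} and offset_dims = {1}, the returned space has index_count =
// {5, 1, 9}, traversing batch dimensions 0 and 2 in full while dimension 1
// stays clamped to 0.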
2196 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
2197     const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
2198   int64_t output_rank = output_shape.dimensions_size();
2199   std::vector<int64_t> index_base(output_rank, 0);
2200   std::vector<int64_t> index_count;
2201   index_count.reserve(output_rank);
2202   for (int64_t i = 0; i < output_rank; i++) {
2203     bool is_output_batch_dim =
2204         !absl::c_binary_search(dim_numbers.offset_dims(), i);
2205     index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
2206   }
2207 
2208   return {std::move(index_base), std::move(index_count),
2209           std::vector<int64_t>(output_rank, 1)};
2210 }
2211 
2212 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
2213 // dimensions while keeping the rest of the output dimensions clamped to 0.
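// For instance (illustrative numbers): with output rank 3, offset_dims =
// {1, 2}, slice_sizes = {1, 8, 6}, and collapsed_slice_dims = {0}, the slice
// size at the collapsed dimension is skipped and the returned space has
// index_count = {1, 8, 6}.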
2214 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
2215     int64_t output_rank, absl::Span<const int64_t> slice_sizes,
2216     const GatherDimensionNumbers& dim_numbers) {
2217   std::vector<int64_t> index_base(output_rank, 0);
2218   std::vector<int64_t> index_count(output_rank, 1);
2219   int64_t slice_sizes_idx = 0;
2220   for (int64_t i = 0; i < output_rank; i++) {
2221     bool is_output_window_dim =
2222         absl::c_binary_search(dim_numbers.offset_dims(), i);
2223     if (is_output_window_dim) {
2224       while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
2225                                    slice_sizes_idx)) {
2226         slice_sizes_idx++;
2227       }
2228       index_count[i] = slice_sizes[slice_sizes_idx++];
2229     }
2230   }
2231 
2232   return {std::move(index_base), std::move(index_count),
2233           std::vector<int64_t>(output_rank, 1)};
2234 }
2235 
2236 // This functor computes the contribution of start_indices to an input index
2237 // corresponding to an output index.  That is, given an output index I, it picks
2238 // out the batch indices in I and uses them to look up a starting index, G, from
2239 // the start indices tensor, and expands G into the input space according to
2240 // start_index_map.
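// For example (illustrative): with start_index_map = {0, 2} and a fetched
// starting index G = (g0, g1), the contribution to a rank-3 input index is
// (g0, 0, g1); input dimension 1 is absent from start_index_map and receives
// no contribution from G.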
2241 class OutputBatchIndexToInputIndex {
2242  public:
2243   // The constructor does some setup work that is amortized across all
2244   // iterations.
2245   explicit OutputBatchIndexToInputIndex(
2246       const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
2247       const Shape& output_shape, const Literal* start_indices)
2248       : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
2249     for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
2250       output_dim_is_batch_dims_.push_back(
2251           !absl::c_binary_search(dim_numbers_.offset_dims(), i));
2252     }
2253 
2254     for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
2255       int64_t index_of_input_dim_in_index_vector =
2256           std::distance(dim_numbers_.start_index_map().begin(),
2257                         absl::c_find(dim_numbers_.start_index_map(), i));
2258       if (index_of_input_dim_in_index_vector ==
2259           dim_numbers_.start_index_map_size()) {
2260         input_dim_value_to_index_vector_.push_back(-1);
2261       } else {
2262         input_dim_value_to_index_vector_.push_back(
2263             index_of_input_dim_in_index_vector);
2264       }
2265     }
2266 
2267     index_vector_index_.resize(start_indices_.shape().dimensions_size());
2268     input_index_.resize(input_shape.dimensions_size());
2269     int64_t index_vector_size =
2270         start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
2271     index_vector_.resize(index_vector_size);
2272   }
2273 
2274   // Returns the contribution of start_indices to the input index corresponding
2275   // to output_index.  See gather_inner_loop_body.
2276   //
2277   // This is conceptually a stateless transformation from output_index to the
2278   // gather input index, but:
2279   //
2280   //  - Instead of allocating memory to represent the gather input index on
2281   //    every invocation we reuse the same storage for the result
2282   //    (input_index_), mutating it in place.
2283   //  - Instead of allocating buffers for temporary values like
2284   //    index_vector_index_ and index_vector on every invocation, we reuse the
2285   //    same storage for all invocations.
2286   //
2287   // This returns a Span into memory owned by the class.
2288   StatusOr<absl::Span<const int64_t>> operator()(
2289       absl::Span<const int64_t> output_index) {
2290     PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
2291     TF_RETURN_IF_ERROR(FetchIndexVector());
2292     PropagateIndexVectorToInputIndex();
2293     return absl::Span<const int64_t>(input_index_);
2294   }
2295 
2296  private:
2297   // Propagates the batch dimensions from the output index into
2298   // index_vector_index_ by mutating index_vector_index_ in place.  Does not
2299   // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
2300   // we iterate over in FetchIndexVector.
2301   void PropagateOutputIndexGatherDimsToIndexVectorIndex(
2302       absl::Span<const int64_t> output_index) {
2303     int64_t index_vector_index_i = 0;
2304     for (int64_t i = 0, e = output_index.size(); i < e; i++) {
2305       if (!output_dim_is_batch_dims_[i]) {
2306         continue;
2307       }
2308 
2309       if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
2310         index_vector_index_i++;
2311       }
2312 
2313       index_vector_index_[index_vector_index_i++] = output_index[i];
2314     }
2315   }
2316 
2317   // Populates index_vector_ by iterating over start_indices_ according to
2318   // index_vector_index_.
2319   Status FetchIndexVector() {
2320     int64_t index_vector_dim = dim_numbers_.index_vector_dim();
2321     for (int64_t i = 0, e = index_vector_.size(); i < e; i++) {
2322       index_vector_index_[index_vector_dim] = i;
2323       auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
2324       TF_RET_CHECK(start_index.has_value());
2325       index_vector_[i] = *start_index;
2326     }
2327     return OkStatus();
2328   }
2329 
2330   // Populates input_index_.
PropagateIndexVectorToInputIndex()2331   void PropagateIndexVectorToInputIndex() {
2332     for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
2333       if (input_dim_value_to_index_vector_[i] != -1) {
2334         input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
2335       }
2336 
2337       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
2338       // remains 0, as set by the constructor.
2339     }
2340   }
2341 
2342   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
2343   // the input index from the index vector.  See
2344   // PropagateIndexVectorToInputIndex.
2345   std::vector<int64_t> input_dim_value_to_index_vector_;
2346 
2347   // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
2348   // dimension.
2349   std::vector<bool> output_dim_is_batch_dims_;
2350 
2351   // The buffer into which we construct an index into start_indices_ to fetch
2352   // the index vector.
2353   std::vector<int64_t> index_vector_index_;
2354 
2355   // The index vector fetched from start_indices_.
2356   std::vector<int64_t> index_vector_;
2357 
2358   // The result computed by this functor.  operator() returns a Span into
2359   // this vector.
2360   std::vector<int64_t> input_index_;
2361 
2362   const GatherDimensionNumbers& dim_numbers_;
2363   const Literal& start_indices_;
2364 };
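
// Worked example (hypothetical gather, not taken from any real module):
// suppose start_indices has shape [4, 2] with index_vector_dim = 1,
// start_index_map = {0, 1}, and the output shape is [4, 8] with
// offset_dims = {1}.  Output dim 0 is then the only batch dimension, so for
// the output index (2, 5) index_vector_index_ becomes (2, *), where * is the
// position FetchIndexVector iterates over.  FetchIndexVector reads
// start_indices[2, 0] and start_indices[2, 1] into index_vector_, and
// PropagateIndexVectorToInputIndex forwards those two values to input dims 0
// and 1 via start_index_map.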

// This functor computes the contribution of the offset indices in an output
// index to an input index.  That is, given an output index I it picks out the
// output offset indices in I and expands them into an index into the input
// shape.
class OutputOffsetIndexToInputIndex {
 public:
  // The constructor does some setup work that is amortized across all
  // iterations.
  explicit OutputOffsetIndexToInputIndex(
      const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
      const Shape& output_shape) {
    std::vector<int64_t> window_index_to_output_index;
    int64_t output_index_count = 0;
    for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
      if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
        window_index_to_output_index.push_back(output_index_count++);
      } else {
        output_index_count++;
      }
    }

    int64_t window_dim_count = 0;
    for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
      if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
        input_dim_value_to_output_index_.push_back(-1);
      } else {
        input_dim_value_to_output_index_.push_back(
            window_index_to_output_index[window_dim_count++]);
      }
    }

    input_index_.resize(input_shape.dimensions_size());
  }

  // Returns the contribution of the window indices to the input index
  // corresponding to output_index.  See gather_inner_loop_body.
  //
  // This is conceptually a stateless transformation from output_index to the
  // window input index, but instead of allocating memory to represent the
  // gather input index on every invocation we reuse the same storage for the
  // result (input_index_), mutating it in place.
  //
  // This returns a Span into memory owned by the class.
  StatusOr<absl::Span<const int64_t>> operator()(
      absl::Span<const int64_t> output_index) {
    PropagateOutputIndexWindowDimsToInputIndex(output_index);
    return absl::Span<const int64_t>(input_index_);
  }

  // Returns for a given 'input_dim' the corresponding output dimension index,
  // or -1 if 'input_dim' is an elided window dimension.
  int64_t input_dim_value_to_output_index(int64_t input_dim) {
    return input_dim_value_to_output_index_[input_dim];
  }

 private:
  // Propagates window dimensions from the output index to input_index_ by
  // mutating input_index_ in place.
  void PropagateOutputIndexWindowDimsToInputIndex(
      absl::Span<const int64_t> output_index) {
    for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
      if (input_dim_value_to_output_index_[i] != -1) {
        input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
      }

      // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
      // remains 0, as set by the constructor.
    }
  }

  // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
  // the input index from the output index.  See
  // PropagateOutputIndexWindowDimsToInputIndex.
  std::vector<int64_t> input_dim_value_to_output_index_;

  // The result computed by this functor.  operator() returns a Span into
  // this vector.
  std::vector<int64_t> input_index_;
};
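
// Worked example (same hypothetical gather as above): with an output shape of
// [4, 8], offset_dims = {1}, a rank-2 operand and collapsed_slice_dims = {0},
// input_dim_value_to_output_index_ is {-1, 1}: input dim 0 is collapsed away
// and input dim 1 is driven by output dim 1.  For the output index (2, 5)
// this functor therefore produces the partial input index (0, 5).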

// Reshapes the gather indices input to have a trailing degenerate `1`
// dimension if necessary.  Hands over the ownership of the newly created
// literal (if there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
    int64_t index_vector_dim, const Literal& start_indices,
    Literal* reshaped_start_indices) {
  if (start_indices.shape().dimensions_size() != index_vector_dim) {
    return std::cref(start_indices);
  }

  std::vector<int64_t> new_shape(start_indices.shape().dimensions().begin(),
                                 start_indices.shape().dimensions().end());
  new_shape.push_back(1);
  TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
                      start_indices.Reshape(new_shape));
  return std::cref(*reshaped_start_indices);
}
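
// For example, with index_vector_dim = 1 and start_indices of shape [5]
// (rank equal to index_vector_dim), the indices are reshaped to [5, 1] so
// that each index vector has an explicit, trailing dimension of length 1.
// Start indices of shape [5, 2] would be returned unchanged.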

Status HloEvaluator::HandleGather(HloInstruction* gather) {
  Literal result = Literal::CreateFromShape(gather->shape());
  const Shape& shape = gather->shape();
  const GatherDimensionNumbers& dim_numbers =
      gather->gather_dimension_numbers();
  const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
  Literal reshaped_start_indices;
  TF_ASSIGN_OR_RETURN(
      const Literal& start_indices,
      ReshapedGatherIndices(dim_numbers.index_vector_dim(),
                            GetEvaluatedLiteralFor(gather->operand(1)),
                            &reshaped_start_indices));

  // We iterate over the gather dimensions in the output shape in an outer loop
  // nest, and iterate over the window dimensions in the output shape in an
  // inner loop nest.

  ShapeUtil::IndexIterationSpace start_indices_iteration_space =
      IterationSpaceForOutputBatchIndices(shape, dim_numbers);
  ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
      IterationSpaceForOutputOffsetIndices(
          shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);

  // Scratch buffers that hold an index in the output shape and the
  // corresponding index in the input shape.
  std::vector<int64_t> input_index(operand.shape().dimensions_size());
  std::vector<int64_t> output_index(gather->shape().dimensions_size());
  std::vector<int64_t> input_index_clamped(operand.shape().dimensions_size());

  OutputBatchIndexToInputIndex output_batch_index_to_input_index(
      &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
      /*output_shape=*/shape, &start_indices);
  OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
      gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
      /*output_shape=*/shape);

  const Shape& operand_shape = operand.shape();
  if (ShapeUtil::IsZeroElementArray(operand_shape)) {
    evaluated_[gather] = std::move(result);
    return OkStatus();
  }

  auto gather_inner_loop_body =
      [&](absl::Span<const int64_t> output_window_index,
          absl::Span<const int64_t> input_gather_index,
          absl::Span<const int64_t> output_gather_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(
        absl::Span<const int64_t> input_window_index,
        output_offset_index_to_input_index(output_window_index));
    for (int i = 0, e = output_index.size(); i < e; i++) {
      output_index[i] = output_gather_index[i] + output_window_index[i];
      DCHECK_LT(output_index[i], shape.dimensions(i));
    }
    for (int i = 0, e = input_gather_index.size(); i < e; i++) {
      int64_t output_dim =
          output_offset_index_to_input_index.input_dim_value_to_output_index(i);
      // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
      // we set the iteration index to 0, so for the purpose of the following
      // calculations we can consider the output dimension size to be 1.
      int64_t output_dim_size =
          output_dim == -1 ? 1 : shape.dimensions(output_dim);
      // Clamp the gather index so that the gather region fits in the operand.
      // input_index_clamped[i] = clamp(input_gather_index[i], 0,
      //                                operand_shape.dimensions(i) -
      //                                output_dim_size);
      input_index_clamped[i] =
          std::min(operand_shape.dimensions(i) - output_dim_size,
                   std::max(int64_t{0}, input_gather_index[i]));
    }
    for (int i = 0, e = input_index.size(); i < e; i++) {
      input_index[i] = input_index_clamped[i] + input_window_index[i];
      DCHECK_GE(input_index[i], 0);
      DCHECK_LT(input_index[i], operand_shape.dimensions(i));
    }
    TF_RETURN_IF_ERROR(
        result.CopyElementFrom(operand, input_index, output_index));
    return true;
  };

  auto gather_outer_loop_body =
      [&](absl::Span<const int64_t> output_gather_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(absl::Span<const int64_t> input_gather_index,
                        output_batch_index_to_input_index(output_gather_index));
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        shape, offset_indices_iteration_space,
        std::bind(gather_inner_loop_body, std::placeholders::_1,
                  input_gather_index, output_gather_index)));
    return true;
  };

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      shape, start_indices_iteration_space, gather_outer_loop_body));
  evaluated_[gather] = std::move(result);
  return OkStatus();
}
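
// Sketch of how the two loop nests above interact, on a hypothetical gather:
// operand shape [3, 3], start_indices of shape [2] (reshaped to [2, 1]),
// index_vector_dim = 1, start_index_map = {0}, collapsed_slice_dims = {0},
// offset_dims = {1}, slice_sizes = {1, 3}, output shape [2, 3].  The outer
// loop visits the batch indices (0, 0) and (1, 0); each fetches one row
// number from start_indices and clamps it into [0, 2], and the inner loop
// then copies the three elements of that operand row into the matching
// output row.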

namespace {
// Reshapes the scatter indices input to have a trailing degenerate `1`
// dimension if necessary.  Hands over the ownership of the newly created
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
    int64_t index_vector_dim, const Literal& indices,
    Literal* reshaped_indices) {
  if (indices.shape().dimensions_size() != index_vector_dim) {
    return std::cref(indices);
  }

  std::vector<int64_t> new_shape(indices.shape().dimensions().begin(),
                                 indices.shape().dimensions().end());
  new_shape.push_back(1);
  TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
  return std::cref(*reshaped_indices);
}

template <bool kForUpdateWindowIndices>
ShapeUtil::IndexIterationSpace GetIterationSpaceImpl(
    absl::Span<const int64_t> updates_dims,
    const ScatterDimensionNumbers& dim_numbers) {
  int64_t updates_rank = updates_dims.size();
  std::vector<int64_t> index_base(updates_rank, 0);
  std::vector<int64_t> index_count(updates_rank, 1);
  for (int64_t i = 0; i < updates_rank; i++) {
    if constexpr (kForUpdateWindowIndices) {
      bool is_update_window_dim =
          absl::c_binary_search(dim_numbers.update_window_dims(), i);
      if (is_update_window_dim) {
        index_count[i] = updates_dims[i];
      }
    } else {
      bool is_update_scatter_dim =
          !absl::c_binary_search(dim_numbers.update_window_dims(), i);
      if (is_update_scatter_dim) {
        index_count[i] = updates_dims[i];
      }
    }
  }
  return {std::move(index_base), std::move(index_count),
          std::vector<int64_t>(updates_rank, 1)};
}
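
// For example, with updates_dims = {2, 3, 4} and update_window_dims = {1, 2},
// the window variant below iterates with index_count = {1, 3, 4} and the
// scatter variant with index_count = {2, 1, 1}; the outer (scatter) and
// inner (window) loops together visit every update index exactly once.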

// Returns a ShapeUtil::IndexIterationSpace that iterates over the update
// scatter dimensions while keeping the rest of the update dimensions clamped
// to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices(
    absl::Span<const int64_t> updates_dims,
    const ScatterDimensionNumbers& dim_numbers) {
  return GetIterationSpaceImpl</*kForUpdateWindowIndices=*/false>(updates_dims,
                                                                  dim_numbers);
}

// Returns a ShapeUtil::IndexIterationSpace that iterates over the update
// window dimensions while keeping the rest of the update dimensions clamped
// to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices(
    absl::Span<const int64_t> updates_dims,
    const ScatterDimensionNumbers& dim_numbers) {
  return GetIterationSpaceImpl</*kForUpdateWindowIndices=*/true>(updates_dims,
                                                                 dim_numbers);
}

// This functor computes the contribution of scatter_indices to an input index
// corresponding to an update index.  That is, given an update index I, it
// picks out the scatter indices in I and uses them to look up a scatter
// index, S, from the scatter indices tensor, and expands S into the input
// space according to scatter_dims_to_operand_dims.
//
// This is similar to the class OutputBatchIndexToInputIndex above, which does
// the corresponding work for gather.
class UpdateScatterIndexToInputIndex {
 public:
  // The constructor does some setup work that is amortized across all
  // iterations.
  explicit UpdateScatterIndexToInputIndex(
      const ScatterDimensionNumbers& dim_numbers, int64_t input_rank,
      int64_t updates_rank, const Literal* scatter_indices)
      : dim_numbers_(dim_numbers), scatter_indices_(*scatter_indices) {
    for (int64_t i = 0; i < updates_rank; i++) {
      update_dim_is_scatter_dims_.push_back(
          !absl::c_binary_search(dim_numbers_.update_window_dims(), i));
    }

    for (int64_t i = 0; i < input_rank; i++) {
      int64_t index_of_input_dim_in_index_vector =
          FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i);
      if (index_of_input_dim_in_index_vector ==
          dim_numbers_.scatter_dims_to_operand_dims_size()) {
        input_dim_value_to_index_vector_.push_back(-1);
      } else {
        input_dim_value_to_index_vector_.push_back(
            index_of_input_dim_in_index_vector);
      }
    }

    index_vector_index_.resize(scatter_indices_.shape().dimensions_size());
    input_index_.resize(input_rank);
    int64_t index_vector_size =
        scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
    index_vector_.resize(index_vector_size);
  }

  // Returns the contribution of scatter_indices to the input index
  // corresponding to update_index.  See scatter_inner_loop_body.
  //
  // This is conceptually a stateless transformation from update_index to the
  // scatter input index, but:
  //
  //  - Instead of allocating memory to represent the scatter input index on
  //    every invocation we reuse the same storage for the result
  //    (input_index_), mutating it in place.
  //  - Instead of allocating buffers for temporary values like
  //    index_vector_index_ and index_vector on every invocation, we reuse the
  //    same storage for all invocations.
  //
  // This returns a Span into memory owned by the class.
  StatusOr<absl::Span<const int64_t>> operator()(
      absl::Span<const int64_t> update_index) {
    PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
    TF_RETURN_IF_ERROR(FetchIndexVector());
    PropagateIndexVectorToInputIndex();
    return absl::Span<const int64_t>(input_index_);
  }

 private:
  // Propagates the scatter index dimensions from the update index into
  // index_vector_index_ by mutating index_vector_index_ in place.  Does not
  // update the dim_numbers.index_vector_dim() dimension -- that's the
  // dimension we iterate over in FetchIndexVector.
  void PropagateUpdateIndexScatterDimsToIndexVectorIndex(
      absl::Span<const int64_t> update_index) {
    int64_t index_vector_index_i = 0;
    for (int64_t i = 0, e = update_index.size(); i < e; i++) {
      if (!update_dim_is_scatter_dims_[i]) {
        continue;
      }

      if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
        index_vector_index_i++;
      }

      index_vector_index_[index_vector_index_i++] = update_index[i];
    }
  }

  // Populates index_vector_ by iterating over scatter_indices_ according to
  // index_vector_index_.
  Status FetchIndexVector() {
    int64_t index_vector_dim = dim_numbers_.index_vector_dim();
    for (int64_t i = 0, e = index_vector_.size(); i < e; i++) {
      index_vector_index_[index_vector_dim] = i;
      index_vector_[i] =
          *scatter_indices_.GetIntegralAsS64(index_vector_index_);
    }
    return OkStatus();
  }

  // Populates input_index_.
  void PropagateIndexVectorToInputIndex() {
    for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
      if (input_dim_value_to_index_vector_[i] != -1) {
        input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
      }

      // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
      // remains 0, as set by the constructor.
    }
  }

  // input_dim_value_to_index_vector_[i] tells us how to compute dimension i
  // of the input index from the index vector.  See
  // PropagateIndexVectorToInputIndex.
  std::vector<int64_t> input_dim_value_to_index_vector_;

  // update_dim_is_scatter_dims_[i] is true iff the update index i is a
  // scatter dimension.
  std::vector<bool> update_dim_is_scatter_dims_;

  // The buffer into which we construct an index into scatter_indices_ to
  // fetch the index vector.
  std::vector<int64_t> index_vector_index_;

  // The index vector fetched from scatter_indices_.
  std::vector<int64_t> index_vector_;

  // The result computed by this functor.  operator() returns a Span
  // into this vector.
  std::vector<int64_t> input_index_;

  const ScatterDimensionNumbers& dim_numbers_;
  const Literal& scatter_indices_;
};

// This functor computes the contribution of the window indices in an update
// index to an input index.  That is, given an update index I it picks out the
// update window indices in I and expands them into a window index into the
// input shape.
//
// This is similar to the class OutputOffsetIndexToInputIndex above, which
// does the corresponding work for gather.
class UpdateWindowIndexToInputIndex {
 public:
  // The constructor does some setup work that is amortized across all
  // iterations.
  explicit UpdateWindowIndexToInputIndex(
      const ScatterDimensionNumbers& dim_numbers, int64_t input_rank,
      int64_t update_rank) {
    std::vector<int64_t> window_index_to_update_index;
    int64_t update_index_count = 0;
    for (int64_t i = 0; i < update_rank; i++) {
      if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
        window_index_to_update_index.push_back(update_index_count++);
      } else {
        update_index_count++;
      }
    }

    int64_t window_dim_count = 0;
    for (int64_t i = 0; i < input_rank; i++) {
      if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
        input_dim_value_to_update_index_.push_back(-1);
      } else {
        input_dim_value_to_update_index_.push_back(
            window_index_to_update_index[window_dim_count++]);
      }
    }

    input_index_.resize(input_rank);
  }

  // Returns the contribution of the window indices to the input index
  // corresponding to update_index.  See scatter_inner_loop_body.
  //
  // This is conceptually a stateless transformation from update_index to the
  // window input index, but instead of allocating memory to represent the
  // scatter input index on every invocation we reuse the same storage for the
  // result (input_index_), mutating it in place.
  //
  // This returns a Span into memory owned by the class.
  StatusOr<absl::Span<const int64_t>> operator()(
      absl::Span<const int64_t> update_index) {
    PropagateUpdateIndexWindowDimsToInputIndex(update_index);
    return absl::Span<const int64_t>(input_index_);
  }

  // Returns for a given 'input_dim' the corresponding update dimension index,
  // or -1 if 'input_dim' is an elided window dimension.
  int64_t input_dim_value_to_update_index(int64_t input_dim) {
    return input_dim_value_to_update_index_[input_dim];
  }

 private:
  // Propagates window dimensions from the update index to input_index_ by
  // mutating input_index_ in place.
  void PropagateUpdateIndexWindowDimsToInputIndex(
      absl::Span<const int64_t> update_index) {
    for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
      if (input_dim_value_to_update_index_[i] != -1) {
        input_index_[i] = update_index[input_dim_value_to_update_index_[i]];
      }

      // If input_dim_value_to_update_index_[i] == -1 then input_index_[i]
      // remains 0, as set by the constructor.
    }
  }

  // input_dim_value_to_update_index_[i] tells us how to compute dimension i
  // of the input index from the update index.  See
  // PropagateUpdateIndexWindowDimsToInputIndex.
  std::vector<int64_t> input_dim_value_to_update_index_;

  // The result computed by this functor.  operator() returns a Span
  // into this vector.
  std::vector<int64_t> input_index_;
};
}  // namespace

Status HloEvaluator::HandleScatter(HloInstruction* hlo) {
  auto* scatter = DynCast<HloScatterInstruction>(hlo);
  const ScatterDimensionNumbers& dim_numbers =
      scatter->scatter_dimension_numbers();
  absl::InlinedVector<const Literal*, 1> operands;
  operands.reserve(scatter->scatter_operand_count());
  for (HloInstruction* operand_inst : scatter->scatter_operands()) {
    operands.push_back(&GetEvaluatedLiteralFor(operand_inst));
  }
  Literal reshaped_scatter_indices;
  TF_ASSIGN_OR_RETURN(
      const Literal& scatter_indices,
      ReshapedScatterIndices(dim_numbers.index_vector_dim(),
                             GetEvaluatedLiteralFor(scatter->scatter_indices()),
                             &reshaped_scatter_indices));
  absl::InlinedVector<const Literal*, 1> updates;
  updates.reserve(operands.size());
  for (HloInstruction* updates_inst : scatter->scatter_updates()) {
    updates.push_back(&GetEvaluatedLiteralFor(updates_inst));
  }
  auto updates_dims = updates[0]->shape().dimensions();
  auto operand_dims = operands[0]->shape().dimensions();

  ShapeUtil::IndexIterationSpace scatter_indices_iteration_space =
      IterationSpaceForUpdateScatterIndices(updates_dims, dim_numbers);
  ShapeUtil::IndexIterationSpace window_indices_iteration_space =
      IterationSpaceForUpdateWindowIndices(updates_dims, dim_numbers);

  std::vector<int64_t> input_index(operand_dims.size());
  std::vector<int64_t> update_index(updates_dims.size());

  UpdateScatterIndexToInputIndex update_scatter_index_to_input_index(
      scatter->scatter_dimension_numbers(),
      /*input_rank=*/operand_dims.size(), updates_dims.size(),
      &scatter_indices);
  UpdateWindowIndexToInputIndex update_window_index_to_input_index(
      scatter->scatter_dimension_numbers(),
      /*input_rank=*/operand_dims.size(), updates_dims.size());

  // Initialize the result with the operand. This makes it easier to handle
  // the updates even when the indices are repeated.
  Literal result = operands.size() > 1 ? LiteralUtil::MakeTuple(operands)
                                       : operands[0]->Clone();
  auto maybe_slice = [](MutableLiteralBase& literal, int idx) {
    if (literal.shape().IsTuple()) {
      return MutableBorrowingLiteral(&literal, {idx});
    }
    DCHECK_EQ(idx, 0);
    return MutableBorrowingLiteral(&literal);
  };

  HloEvaluator embedded_evaluator;
  auto scatter_inner_loop_body =
      [&](absl::Span<const int64_t> update_window_index,
          absl::Span<const int64_t> input_scatter_index,
          absl::Span<const int64_t> update_scatter_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(
        absl::Span<const int64_t> input_window_index,
        update_window_index_to_input_index(update_window_index));
    for (int i = 0, e = update_index.size(); i < e; i++) {
      update_index[i] = update_scatter_index[i] + update_window_index[i];
      DCHECK_LT(update_index[i], updates_dims[i]);
    }
    for (int i = 0, e = input_scatter_index.size(); i < e; i++) {
      int64_t update_dim =
          update_window_index_to_input_index.input_dim_value_to_update_index(i);
      // If 'update_dim' is -1, it means 'i' is an elided window dim. This
      // means we set the iteration index to 0, so for the purpose of the
      // following calculations we can consider the update dimension size to
      // be 1.
      int64_t update_dim_size = update_dim == -1 ? 1 : updates_dims[update_dim];
      // If any part of the update region is out-of-bounds, then do not
      // perform any update on the input.
      if ((input_scatter_index[i] < 0) ||
          (input_scatter_index[i] > operand_dims[i] - update_dim_size)) {
        return true;
      }
    }
    for (int i = 0, e = input_index.size(); i < e; i++) {
      input_index[i] = input_scatter_index[i] + input_window_index[i];
    }

    absl::InlinedVector<Literal, 2> to_apply_args;
    to_apply_args.reserve(operands.size() + updates.size());
    for (int i = 0, n = operands.size(); i < n; ++i) {
      to_apply_args.push_back(
          LiteralUtil::GetScalarLiteral(maybe_slice(result, i), input_index));
    }
    for (int i = 0, n = operands.size(); i < n; ++i) {
      to_apply_args.push_back(
          LiteralUtil::GetScalarLiteral(*updates[i], update_index));
    }
    Literal updated_result =
        embedded_evaluator.Evaluate(*scatter->to_apply(), to_apply_args)
            .value();
    // Clear visit states so that we can reuse the evaluator on the same
    // computation.
    embedded_evaluator.ResetVisitStates();
    for (int i = 0, n = operands.size(); i < n; ++i) {
      auto result_slice = maybe_slice(result, i);
      LiteralUtil::SetScalarLiteral(result_slice, input_index,
                                    maybe_slice(updated_result, i));
    }
    return true;
  };

  auto scatter_outer_loop_body =
      [&](absl::Span<const int64_t> update_scatter_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(
        absl::Span<const int64_t> input_scatter_index,
        update_scatter_index_to_input_index(update_scatter_index));
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        updates[0]->shape(), window_indices_iteration_space,
        [&](absl::Span<const int64_t> update_window_index) {
          return scatter_inner_loop_body(
              update_window_index, input_scatter_index, update_scatter_index);
        }));
    return true;
  };

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      updates[0]->shape(), scatter_indices_iteration_space,
      scatter_outer_loop_body));
  evaluated_[scatter] = std::move(result);
  return OkStatus();
}
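
// Sketch of the loop structure above on a hypothetical scatter: operand shape
// [3, 3], scatter_indices of shape [2] (reshaped to [2, 1]), updates of shape
// [2, 3], update_window_dims = {1}, inserted_window_dims = {0},
// scatter_dims_to_operand_dims = {0}, index_vector_dim = 1.  The outer loop
// visits the update scatter indices (0, 0) and (1, 0); each fetches a target
// row from scatter_indices, and the inner loop applies to_apply elementwise
// to the current result row and the update row.  Unlike gather, an
// out-of-bounds row is skipped entirely rather than clamped.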

Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
  const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
  TF_RET_CHECK(broadcast->shape().element_type() ==
               operand.shape().element_type())
      << " broadcast from a different data type is not supported";
  TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
      << "broadcast dimensions is of size: " << broadcast->dimensions().size()
      << " and rank of operand_to_broadcast is: " << operand.shape().rank();
  // Checks that operand's dimensions are the same as the broadcast's
  // dimensions along the dimensions to be broadcasted.
  for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
    auto operand_dim_size = operand.shape().dimensions(i);
    auto broadcast_dim_size =
        broadcast->shape().dimensions(broadcast->dimensions(i));
    TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
        "Operand dimension %d is broadcast to output dimension %d, but the "
        "sizes of these two dims do not match (%d vs %d): %s",
        i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
        broadcast->ToString());
  }

  TF_ASSIGN_OR_RETURN(
      evaluated_[broadcast],
      operand.Broadcast(broadcast->shape(), broadcast->dimensions()));

  return OkStatus();
}

Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
  evaluated_[after_all] = LiteralUtil::CreateToken();
  return OkStatus();
}

Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
  // AddDependency just forwards its zero-th operand.
  evaluated_[add_dependency] =
      GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
  return OkStatus();
}

Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  const auto result_shape = get_tuple_element->shape();
  const int64_t index = get_tuple_element->tuple_index();

  auto operand = get_tuple_element->operand(0);
  TF_ASSIGN_OR_RETURN(
      auto inferred_return_shape,
      ShapeInference::InferGetTupleElementShape(operand->shape(), index));
  TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
      << "return shape set to: " << ShapeUtil::HumanString(result_shape)
      << " but is inferred to be: "
      << ShapeUtil::HumanString(inferred_return_shape);

  const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);

  evaluated_[get_tuple_element] =
      Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
  return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
                                                /*dest_shape_index=*/{},
                                                /*src_shape_index=*/{index});
}

Status HloEvaluator::HandleCopy(HloInstruction* copy) {
  TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
  evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
  return OkStatus();
}

Status HloEvaluator::HandleAsyncStart(HloInstruction* async_start) {
  std::vector<const Literal*> arg_literals;
  arg_literals.reserve(async_start->operands().size());
  for (auto operand : async_start->operands()) {
    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
    arg_literals.push_back(&arg_literal);
  }

  HloEvaluator embedded_evaluator;
  embedded_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(
      Literal result,
      embedded_evaluator.Evaluate(*async_start->async_wrapped_computation(),
                                  arg_literals));

  evaluated_[async_start] = Literal(async_start->shape());
  // Copy the operand values to the index {0, i} of the output.
  for (int i = 0; i < arg_literals.size(); ++i) {
    TF_RETURN_IF_ERROR(evaluated_[async_start].CopyFrom(
        *arg_literals[i], /*dest_shape_index=*/{0, i},
        /*src_shape_index=*/{}));
  }
  // Move the output value to the index {1} of the output.
  TF_RETURN_IF_ERROR(evaluated_[async_start].MoveFrom(
      std::move(result), /*dest_shape_index=*/{1}));

  return OkStatus();
}
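
// The evaluated async-start value is thus a tuple whose element {0} holds a
// tuple of the operand literals, whose element {1} holds the wrapped
// computation's result, and whose remaining elements (if any) keep the
// zero-initialized values from the Literal constructed above.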

Status HloEvaluator::HandleAsyncUpdate(HloInstruction* async_update) {
  const Literal& operand_tuple_literal =
      GetEvaluatedLiteralFor(async_update->operand(0));
  evaluated_[async_update] = Literal(async_update->shape());
  TF_RETURN_IF_ERROR(evaluated_[async_update].CopyFrom(operand_tuple_literal,
                                                       /*dest_shape_index=*/{},
                                                       /*src_shape_index=*/{}));
  return OkStatus();
}

Status HloEvaluator::HandleAsyncDone(HloInstruction* async_done) {
  const Literal& operand_tuple_literal =
      GetEvaluatedLiteralFor(async_done->operand(0));
  evaluated_[async_done] = Literal(async_done->shape());
  TF_RETURN_IF_ERROR(evaluated_[async_done].CopyFrom(operand_tuple_literal,
                                                     /*dest_shape_index=*/{},
                                                     /*src_shape_index=*/{1}));
  return OkStatus();
}

Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
  if (copy_start->user_count() != 1 ||
      copy_start->users().at(0)->opcode() != HloOpcode::kCopyDone) {
    return tensorflow::errors::FailedPrecondition(
        "Cannot evaluate a kCopyStart that doesn't have a single kCopyDone "
        "user.");
  }

  // The context in index {2} is undefined, but since we can't represent
  // undefined values using a Literal, we just use 0. This should be safe
  // though since we ensure that the only user of a kCopyStart is a kCopyDone
  // which consumes the context. Also note that MakeTuple copies its
  // arguments, so this is memory-safe.
  const Literal context_literal = LiteralUtil::CreateR0<uint32_t>(0);
  evaluated_[copy_start] = LiteralUtil::MakeTuple(
      {&GetEvaluatedLiteralFor(copy_start->operand(0)),
       &GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
  return OkStatus();
}
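
// For example, a kCopyStart of an f32[4] operand evaluates to the tuple
// (f32[4], f32[4], u32[]) in which both array elements hold the operand's
// value and the context element holds the placeholder 0 created above.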

Status HloEvaluator::HandleCopyDone(HloInstruction* copy_done) {
  const HloInstruction* operand = copy_done->operand(0);
  if (operand->opcode() != HloOpcode::kCopyStart) {
    return tensorflow::errors::FailedPrecondition(
        "Cannot evaluate a kCopyDone that doesn't have a kCopyStart as "
        "operand.");
  }

  const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
  evaluated_[copy_done] =
      Literal(ShapeUtil::GetTupleElementShape(operand->shape(), /*index=*/0));
  TF_RETURN_IF_ERROR(evaluated_[copy_done].CopyFrom(operand_tuple_literal,
                                                    /*dest_shape_index=*/{},
                                                    /*src_shape_index=*/{0}));
  return OkStatus();
}

Status HloEvaluator::HandleCall(HloInstruction* call) {
  auto* computation = call->to_apply();
  auto operands = call->operands();

  std::vector<const Literal*> arg_literals;
  arg_literals.reserve(operands.size());
  for (auto operand : operands) {
    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
    arg_literals.push_back(&arg_literal);
  }

  std::unique_ptr<HloEvaluator> embedded_evaluator =
      CreateEmbedded(max_loop_iterations_);
  embedded_evaluator->set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result,
                      embedded_evaluator->Evaluate(*computation, arg_literals));

  evaluated_[call] = std::move(result);
  return OkStatus();
}

Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
  HloModuleConfig config;
  // Attach the cloned computation to an empty HLO module so the existing ones
  // are not modified.
  HloModule empty_hlo_module("EmptyModuleForFusion", config);
  HloCloneContext context(&empty_hlo_module);
  auto cloned_fused_computation =
      fusion->fused_instructions_computation()->Clone(
          /*suffix=*/"clone_with_layout", &context);
  for (auto* instruction : cloned_fused_computation->instructions()) {
    if (!LayoutUtil::HasLayout(instruction->shape())) {
      LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
    }
  }
  auto readded_computation =
      empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));

  auto operands = fusion->operands();
  std::vector<const Literal*> arg_literals;
  arg_literals.reserve(operands.size());
  for (auto operand : operands) {
    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
    arg_literals.push_back(&arg_literal);
  }

  std::unique_ptr<HloEvaluator> embedded_evaluator =
      CreateEmbedded(max_loop_iterations_);
  embedded_evaluator->set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator->Evaluate(
                                          *readded_computation, arg_literals));

  evaluated_[fusion] = std::move(result);
  return OkStatus();
}

Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
  const auto& branch_index_literal =
      GetEvaluatedLiteralFor(conditional->operand(0));
  int branch_index;
  if (conditional->operand(0)->shape().element_type() == PRED) {
    branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
  } else {
    branch_index = branch_index_literal.Get<int32_t>({});
    if (branch_index < 0 || branch_index >= conditional->branch_count()) {
      branch_index = conditional->branch_count() - 1;
    }
  }
  const auto& branch_computation_arg =
      GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));

  std::unique_ptr<HloEvaluator> embedded_evaluator =
      CreateEmbedded(max_loop_iterations_);
  embedded_evaluator->set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result,
                      embedded_evaluator->Evaluate(
                          *conditional->branch_computation(branch_index),
                          {&branch_computation_arg}));

  evaluated_[conditional] = std::move(result);
  return OkStatus();
}
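
// For example (hypothetical conditional): with branch_count == 3 and an S32
// branch index evaluating to 7 or to -1, the clamping above selects the last
// branch, branch_computation(2).  A PRED index selects branch 0 when true
// and branch 1 when false.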

Status HloEvaluator::HandleSelect(HloInstruction* select) {
  const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
  const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
  const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));

  // If the predicate is scalar, no element-wise selection is needed.
  if (ShapeUtil::IsScalar(pred.shape())) {
    if (pred.Get<bool>({})) {
      evaluated_[select] = on_true.Clone();
    } else {
      evaluated_[select] = on_false.Clone();
    }
    return OkStatus();
  }

  return DefaultAction(select);
}

namespace {

StatusOr<Literal> CreateScalarLiteral(int64_t value,
                                      PrimitiveType element_type) {
  Literal result;
  switch (element_type) {
    case S8:
      result = LiteralUtil::CreateR0(static_cast<int8_t>(value));
      break;
    case U8:
      result = LiteralUtil::CreateR0(static_cast<uint8_t>(value));
      break;
    case S16:
      result = LiteralUtil::CreateR0(static_cast<int16_t>(value));
      break;
    case U16:
      result = LiteralUtil::CreateR0(static_cast<uint16_t>(value));
      break;
    case S32:
      result = LiteralUtil::CreateR0(static_cast<int32_t>(value));
      break;
    case U32:
      result = LiteralUtil::CreateR0(static_cast<uint32_t>(value));
      break;
    case S64:
      result = LiteralUtil::CreateR0(static_cast<int64_t>(value));
      break;
    case U64:
      result = LiteralUtil::CreateR0(static_cast<uint64_t>(value));
      break;
    default:
      return InvalidArgument("Unsupported element type.");
  }
  return result;
}
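
// For example, CreateScalarLiteral(7, S32) produces the same literal as
// LiteralUtil::CreateR0<int32_t>(7); a floating-point element type falls
// through to the InvalidArgument error since only integral types are
// handled above.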

// Parses the while loop if it matches one of the known patterns. Returns the
// value of the loop induction variable after the loop execution if the loop is
// static.
StatusOr<Literal> TryParseAndEvaluateWhileInductionVar(
    HloInstruction* while_hlo) {
  std::optional<ParsedWhileLoop> parsed_while_loop =
      PatternMatchParseWhileLoop(while_hlo);
  if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic()) {
    return FailedPrecondition(
        "Cannot evaluate a while loop's induction variable since the loop "
        "does not match a known loop pattern or the loop is not static.");
  }
  int64_t induction_var_value =
      parsed_while_loop->static_while_loop->induction_var_init_value +
      parsed_while_loop->static_while_loop->trip_count *
          parsed_while_loop->static_while_loop->step_size;
  Shape result_shape = while_hlo->shape().tuple_shapes(
      parsed_while_loop->static_while_loop->induction_var_index);
  TF_ASSIGN_OR_RETURN(
      Literal result,
      CreateScalarLiteral(induction_var_value, result_shape.element_type()));
  std::vector<Literal*> while_result_element_ptrs;
  while_result_element_ptrs.reserve(while_hlo->shape().tuple_shapes_size());
  std::vector<Literal> while_result_elements(
      while_hlo->shape().tuple_shapes_size());
  for (int i = 0; i < while_hlo->shape().tuple_shapes_size(); ++i) {
    if (i == parsed_while_loop->static_while_loop->induction_var_index) {
      while_result_element_ptrs.push_back(&result);
    } else {
      const Shape& shape = while_hlo->shape().tuple_shapes(i);
      while_result_elements[i] =
          Literal::CreateFromShapeWithUnknownLeafArrays(shape);
      while_result_element_ptrs.push_back(&while_result_elements[i]);
    }
  }
  return LiteralUtil::MakeTuple(while_result_element_ptrs);
}
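
// For example (hypothetical loop): a while loop equivalent to
// "for (i = 0; i < 10; i += 2)" parses with induction_var_init_value = 0,
// trip_count = 5 and step_size = 2, so the returned tuple holds
// 0 + 5 * 2 = 10 in the induction variable's slot and unknown leaf arrays
// everywhere else.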

}  // namespace

Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
  HloComputation* cond_comp = while_hlo->while_condition();
  HloComputation* body_comp = while_hlo->while_body();
  // Initialize the loop-carried value with the input to the While instruction.
  auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
  if (!lcv.IsKnown()) {
    std::optional<ParsedWhileLoop> parsed_while_loop =
        PatternMatchParseWhileLoop(while_hlo);
    evaluated_[while_hlo] =
        Literal::CreateFromShapeWithUnknownLeafArrays(while_hlo->shape());
    if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic() ||
        visitor_shape_index_.size() != 1 ||
        parsed_while_loop->static_while_loop->induction_var_index !=
            visitor_shape_index_[0]) {
      return OkStatus();
    }
    Shape induction_var_shape =
        ShapeUtil::GetSubshape(while_hlo->shape(), visitor_shape_index_);
    int64_t trip_count = parsed_while_loop->static_while_loop->trip_count;
    TF_ASSIGN_OR_RETURN(
        Literal induction_var_val,
        CreateScalarLiteral(trip_count, induction_var_shape.element_type()));
    TF_RETURN_IF_ERROR(evaluated_[while_hlo].CopyFrom(
        induction_var_val, /*dest_shape_index=*/visitor_shape_index_,
        /*src_shape_index=*/{}));
    return OkStatus();
  }
  bool keep_going = true;
  int64_t iteration_count = 0;
  std::unique_ptr<HloEvaluator> cond_evaluator =
      CreateEmbedded(max_loop_iterations_);
  cond_evaluator->set_dynamic_dimension_inference(dynamic_dimension_inference_);
  std::unique_ptr<HloEvaluator> loop_body_evaluator =
      CreateEmbedded(max_loop_iterations_);
  loop_body_evaluator->set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  while (keep_going) {
    if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
      StatusOr<Literal> result =
          TryParseAndEvaluateWhileInductionVar(while_hlo);
      if (result.ok()) {
        lcv = std::move(result).value();
        break;
      } else {
        return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
                               while_hlo->name(), max_loop_iterations_);
      }
    }
    TF_ASSIGN_OR_RETURN(auto cond_val,
                        cond_evaluator->Evaluate(*cond_comp, {&lcv}));
    keep_going = cond_val.GetFirstElement<bool>();
    if (keep_going) {
      TF_ASSIGN_OR_RETURN(auto body_val,
                          loop_body_evaluator->Evaluate(*body_comp, {&lcv}));
      VLOG(3) << "Loop iteration result: " << body_val.ToString();
      lcv = std::move(body_val);
      cond_evaluator->ResetVisitStates();
      loop_body_evaluator->ResetVisitStates();
    }
  }
  evaluated_[while_hlo] = std::move(lcv);
  return OkStatus();
}

namespace {
template <typename NativeT>
Literal ExtractLiteralFromIndexPositions(const Literal& from,
                                         absl::Span<int64_t const> indices,
                                         bool extract_as_scalar) {
  if (extract_as_scalar) {
    return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
  }
  // We use an InlinedVector here because we need to convert it to an
  // absl::Span later, and this would not work with std::vector<bool>.
  absl::InlinedVector<NativeT, 10> values;
  for (int64_t index : indices) {
    values.push_back(from.Get<NativeT>({index}));
  }
  return LiteralUtil::CreateR1<NativeT>(values);
}

StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
                                            absl::Span<int64_t const> indices,
                                            bool extract_as_scalar = false) {
  if (extract_as_scalar) {
    CHECK_EQ(indices.size(), 1);
  }
  PrimitiveType type = from.shape().element_type();
  switch (type) {
    case PRED: {
      return ExtractLiteralFromIndexPositions<bool>(from, indices,
                                                    extract_as_scalar);
    }
    case U8: {
      return ExtractLiteralFromIndexPositions<uint8_t>(from, indices,
                                                       extract_as_scalar);
    }
    case S8: {
      return ExtractLiteralFromIndexPositions<int8_t>(from, indices,
                                                      extract_as_scalar);
    }
    case BF16: {
      return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
                                                        extract_as_scalar);
    }
    case F16: {
      return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
                                                           extract_as_scalar);
    }
    case U16: {
      return ExtractLiteralFromIndexPositions<uint16_t>(from, indices,
                                                        extract_as_scalar);
    }
    case S16: {
      return ExtractLiteralFromIndexPositions<int16_t>(from, indices,
                                                       extract_as_scalar);
    }
    case F32: {
      return ExtractLiteralFromIndexPositions<float>(from, indices,
                                                     extract_as_scalar);
    }
    case U32: {
      return ExtractLiteralFromIndexPositions<uint32_t>(from, indices,
                                                        extract_as_scalar);
    }
    case S32: {
      return ExtractLiteralFromIndexPositions<int32_t>(from, indices,
                                                       extract_as_scalar);
    }
    case F64: {
      return ExtractLiteralFromIndexPositions<double>(from, indices,
                                                      extract_as_scalar);
    }
    case C64: {
      return ExtractLiteralFromIndexPositions<std::complex<float>>(
          from, indices, extract_as_scalar);
    }
    case U64: {
      return ExtractLiteralFromIndexPositions<uint64_t>(from, indices,
                                                        extract_as_scalar);
    }
    case S64: {
      return ExtractLiteralFromIndexPositions<int64_t>(from, indices,
                                                       extract_as_scalar);
    }
    case C128: {
      return ExtractLiteralFromIndexPositions<std::complex<double>>(
          from, indices, extract_as_scalar);
    }
    default:
      return InvalidArgument("Unsupported type for Sort: %s",
                             PrimitiveType_Name(type));
  }
}
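
// For example, given an F32 literal [10.0, 20.0, 30.0] and indices {2, 0},
// ExtractFromIndexPositions returns the R1 literal [30.0, 10.0].  With
// extract_as_scalar = true and a single index it returns an R0 literal
// instead; this is how the sort comparator below builds its scalar
// arguments.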
}  // namespace

Status HloEvaluator::HandleSort(HloInstruction* sort) {
  TF_RET_CHECK(sort->operand_count() >= 1)
      << "Expected at least 1 operand for sort";
  for (int64_t i = 1; i < sort->operand_count(); ++i) {
    TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
                                           sort->operand(i)->shape()))
        << "All Sort operands must have the same dimensions";
  }

  if (VLOG_IS_ON(3)) {
    for (int64_t i = 0; i < sort->operand_count(); ++i) {
      VLOG(3) << "HandleSort operand " << i << " literal: "
              << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
    }
  }
  Shape key_shape = sort->operand(0)->shape();
  auto rank = key_shape.rank();
  std::vector<Literal> result_literals;
  result_literals.reserve(sort->operand_count());
  for (int64_t i = 0; i < sort->operand_count(); ++i) {
    result_literals.emplace_back(sort->operand(i)->shape());
  }
  std::vector<int64_t> zero_base(rank, 0);
  std::vector<int64_t> increment(rank, 1);
  int64_t sort_dim = sort->dimensions(0);
  int64_t sort_dim_elements = key_shape.dimensions(sort_dim);
  TF_RET_CHECK(sort_dim >= 0 && sort_dim < increment.size())
      << "Unexpected out-of-bound sort dimension " << sort_dim
      << " accessing increment of size " << increment.size();
  increment[sort_dim] = sort_dim_elements;
  std::unique_ptr<HloEvaluator> embedded_evaluator =
      CreateEmbedded(max_loop_iterations_);
  // Iterate through each dimension except 'sort_dim'.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      key_shape, zero_base, key_shape.dimensions(), increment,
      [&](absl::Span<const int64_t> indices) -> StatusOr<bool> {
        // Extract a slice from each operand literal that corresponds to
        // exactly the row in dimension 'sort_dim'.
        std::vector<int64_t> limit_indices(indices.begin(), indices.end());
        absl::c_for_each(limit_indices, [](int64_t& index) { ++index; });
        limit_indices[sort_dim] = sort_dim_elements;
        std::vector<Literal> literals_to_sort;
        literals_to_sort.reserve(sort->operand_count());
        for (int64_t i = 0; i < sort->operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(auto literal_to_sort,
                              GetEvaluatedLiteralFor(sort->operand(i))
                                  .Slice(indices, limit_indices)
                                  .Reshape({sort_dim_elements}));
          literals_to_sort.push_back(std::move(literal_to_sort));
        }
        std::vector<int64_t> indices_to_sort(sort_dim_elements);
        std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
        Status compare_status = OkStatus();
        auto comparator = [sort, &compare_status,
                           embedded_evaluator = embedded_evaluator.get(),
                           &literals_to_sort](int64_t a, int64_t b) {
          std::vector<Literal> literals;
          literals.reserve(2 * sort->operand_count());
          for (int64_t i = 0; i < sort->operand_count(); ++i) {
            auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
                                                 /*extract_as_scalar=*/true);
            if (!lhs.ok()) {
              compare_status = lhs.status();
              return false;
            }
            literals.push_back(std::move(lhs.ValueOrDie()));
            auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
                                                 /*extract_as_scalar=*/true);
            if (!rhs.ok()) {
              compare_status = rhs.status();
              return false;
            }
            literals.push_back(std::move(rhs.ValueOrDie()));
          }
          std::vector<const Literal*> literal_ptrs;
          absl::c_transform(literals, std::back_inserter(literal_ptrs),
                            [](const Literal& literal) { return &literal; });

          auto computed_result =
              embedded_evaluator->Evaluate(*sort->to_apply(), literal_ptrs);
          // Clear visit states so that we can use the evaluator again
          // on the same computation.
          embedded_evaluator->ResetVisitStates();
          if (!computed_result.ok()) {
            compare_status = computed_result.status();
            return false;
          }
          return computed_result.ValueOrDie().Get<bool>({});
        };
        if (Cast<HloSortInstruction>(sort)->is_stable()) {
          std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
                           comparator);
        } else {
          std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
        }
        if (!compare_status.ok()) {
          return compare_status;
        }
        std::vector<int64_t> slice_dimensions(rank, 1);
        slice_dimensions[sort_dim] = sort_dim_elements;
        std::vector<int64_t> start_indices(rank, 0);
        for (int64_t i = 0; i < sort->operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(
              Literal sorted_literal,
              ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
          TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
                              sorted_literal.Reshape(slice_dimensions));
          TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
              sorted_literal_reshaped, start_indices, indices,
              slice_dimensions));
        }
        return true;
      }));

  if (sort->operand_count() == 1) {
    evaluated_[sort] = std::move(result_literals[0]);
  } else {
    std::vector<const Literal*> literal_ptrs;
    absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
                      [](const Literal& literal) { return &literal; });

    Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
    VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();

    evaluated_[sort] = std::move(result_tuple);
  }
  return OkStatus();
}

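// Illustrative sketch (added commentary, not part of the original source):
// the comparator computation receives 2 * operand_count scalar parameters,
// interleaved as (lhs_0, rhs_0, lhs_1, rhs_1, ...), exactly the order in
// which HandleSort pushes them above. A two-operand stable sort over
// dimension 1 might look roughly like this in HLO text:
//
//   compare {
//     p0 = f32[] parameter(0)  // lhs key
//     p1 = f32[] parameter(1)  // rhs key
//     p2 = s32[] parameter(2)  // lhs value (unused by the comparison)
//     p3 = s32[] parameter(3)  // rhs value (unused by the comparison)
//     ROOT lt = pred[] compare(p0, p1), direction=LT
//   }
//   sorted = (f32[2,3], s32[2,3]) sort(keys, values), dimensions={1},
//            is_stable=true, to_apply=compare
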
static bool IsScalarAdd(HloComputation* computation) {
  HloInstruction* instruction = computation->root_instruction();
  if (instruction->opcode() == HloOpcode::kAdd &&
      computation->num_parameters() == 2) {
    const HloInstruction* lhs = instruction->operand(0);
    const HloInstruction* rhs = instruction->operand(1);
    return lhs->opcode() == HloOpcode::kParameter &&
           ShapeUtil::IsScalar(lhs->shape()) &&
           rhs->opcode() == HloOpcode::kParameter &&
           ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
  }
  return false;
}

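// Example (added commentary): the shape of computation IsScalarAdd matches,
// in HLO text form, is a root add of two distinct scalar parameters:
//
//   add {
//     p0 = f32[] parameter(0)
//     p1 = f32[] parameter(1)
//     ROOT sum = f32[] add(p0, p1)
//   }
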
// Runs a single step of the reduction inner loop: applies the user-provided
// computation to the current accumulator and input element. (Until the
// reduction completes, the output element doubles as the accumulator.)
static StatusOr<bool> PerformReductionStep(
    bool is_tuple, absl::Span<const int64_t> input_index,
    absl::Span<const int64_t> output_index,
    absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
    HloComputation* computation, HloEvaluator* embedded_evaluator) {
  int num_args = results.size();

  absl::InlinedVector<Literal, 1> arg_values;
  arg_values.reserve(num_args);
  absl::InlinedVector<Literal, 1> accumulators;
  accumulators.reserve(num_args);
  for (int64_t i = 0; i < num_args; ++i) {
    arg_values.emplace_back(
        ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
    accumulators.emplace_back(
        ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));

    TF_RETURN_IF_ERROR(
        arg_values[i].CopyElementFrom(*input_args[i], input_index, {}));
    TF_RETURN_IF_ERROR(
        accumulators[i].CopyElementFrom(results[i], output_index, {}));
  }

  // Evaluate computation with specified literal operands.
  absl::InlinedVector<Literal*, 2> embedded_operands;
  for (Literal& accumulator : accumulators) {
    embedded_operands.push_back(&accumulator);
  }
  for (Literal& local_input : arg_values) {
    embedded_operands.push_back(&local_input);
  }

  TF_ASSIGN_OR_RETURN(
      Literal computed_result,
      embedded_evaluator->Evaluate(*computation, embedded_operands));

  // Clear visit states so that we can use the evaluator again on the same
  // computation.
  embedded_evaluator->ResetVisitStates();

  if (is_tuple) {
    std::vector<Literal> computed_results = computed_result.DecomposeTuple();
    for (int64_t i = 0; i < num_args; ++i) {
      TF_RETURN_IF_ERROR(
          results[i].CopyElementFrom(computed_results[i], {}, output_index));
    }
  } else {
    TF_RETURN_IF_ERROR(
        results[0].CopyElementFrom(computed_result, {}, output_index));
  }

  return true;
}

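// Worked example (added commentary): for a scalar-add reduction, one call to
// PerformReductionStep computes, in effect,
//   results[0][output_index] += input_args[0][input_index]
// by evaluating the add computation on (accumulator, input element) and
// copying the scalar result back into the output literal.
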
static StatusOr<bool> GenerateReduceOutputElement(
    bool is_tuple, absl::Span<const int64_t> output_index,

    absl::Span<const Literal* const> init_values,
    absl::Span<const Literal* const> input_args, absl::Span<Literal> results,

    HloComputation* function, HloEvaluator* embedded_evaluator,

    absl::Span<const int64_t> arg_dim_steps,
    absl::Span<const int64_t> arg_dim_counts,
    absl::Span<const int64_t> result_to_arg_index) {
  bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) &&
                      IsScalarAdd(function) && !is_tuple;

  const Shape& arg_shape = input_args[0]->shape();
  absl::Span<const int64_t> arg_dimensions = arg_shape.dimensions();
  std::vector<int64_t> base(arg_dimensions.size());
  for (int64_t i = 0; i < output_index.size(); ++i) {
    base[result_to_arg_index[i]] = output_index[i];
  }

  for (int64_t i = 0; i < results.size(); ++i) {
    TF_RETURN_IF_ERROR(
        results[i].CopyElementFrom(*init_values[i], {}, output_index));
  }

  if (use_fast_add) {
    double computed_result = *init_values[0]->GetAsDouble({});
    auto reduction_step =
        [&](absl::Span<const int64_t> input_index) -> StatusOr<bool> {
      double argument = *input_args[0]->GetAsDouble(input_index);
      computed_result += argument;
      return true;
    };
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step));
    TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result));
    return true;
  }

  // Iterates only over reduced shape, as counts and steps are set to zero
  // for all non-reduced dimensions.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      arg_shape, base, arg_dim_counts, arg_dim_steps,
      [&](absl::Span<const int64_t> input_index) {
        return PerformReductionStep(is_tuple, input_index, output_index,
                                    input_args, results, function,
                                    embedded_evaluator);
      }));
  return true;
}

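// Worked example (added commentary): reducing an f32[4,5] input over
// dimension 1 yields arg_dim_counts = {0, 5} and arg_dim_steps = {0, 1}.
// For output index {i}, base = {i, 0}, so the loops above visit input
// indices {i, 0} through {i, 4} while the non-reduced dimension stays
// pinned at its base value.
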
Status HloEvaluator::HandleReduce(HloInstruction* instr) {
  HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
  int64_t num_args = reduce->inputs().size();
  absl::Span<const int64_t> dimensions_to_reduce(reduce->dimensions());
  HloComputation* function = reduce->to_apply();

  absl::InlinedVector<const Shape*, 1> operand_shapes;
  for (const HloInstruction* operand : reduce->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                      ShapeInference::InferReduceShape(
                          operand_shapes, dimensions_to_reduce,
                          /*to_apply=*/function->ComputeProgramShape()));
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(),
                                                        inferred_return_shape))
      << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
      << " but is inferred to be: "
      << ShapeUtil::HumanString(inferred_return_shape);

  absl::InlinedVector<const Literal*, 1> input_args(num_args);
  absl::InlinedVector<const Literal*, 1> init_values(num_args);
  for (int64_t i = 0; i < num_args; ++i) {
    input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]);
    VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString();
    init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]);
    VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString();
    TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape()));
  }

  // All args and results have the same dimensions, so pick an arbitrary one.
  const Shape& arg_shape = input_args[0]->shape();
  const Shape& out_shape = inferred_return_shape;
  bool is_tuple = out_shape.IsTuple();
  const Shape& output_shape = inferred_return_shape.IsTuple()
                                  ? inferred_return_shape.tuple_shapes(0)
                                  : inferred_return_shape;

  absl::Span<const int64_t> arg_dimensions = arg_shape.dimensions();

  // All increments are set to 0.
  std::vector<int64_t> arg_dim_steps(arg_dimensions.size());

  // All counts are set to 0.
  std::vector<int64_t> arg_dim_counts(arg_dimensions.size());

  // Set steps and counts for reduced dimensions.
  // This avoids iterating over non-reduced dimensions, as their step
  // and count are set to zero.
  for (const int64_t dim : dimensions_to_reduce) {
    arg_dim_steps[dim] = 1;
    arg_dim_counts[dim] = arg_dimensions[dim];
  }

  // Map each dimension in the result to a dimension in arg that isn't
  // being reduced.
  std::vector<int64_t> result_to_arg_index;
  for (int64_t i = 0; i < arg_dimensions.size(); ++i) {
    if (arg_dim_steps[i] == 0) {
      result_to_arg_index.push_back(i);
    }
  }

  std::unique_ptr<HloEvaluator> embedded_evaluator =
      CreateEmbedded(max_loop_iterations_);
  absl::InlinedVector<Literal, 1> results(num_args);
  for (int64_t i = 0; i < num_args; ++i) {
    results[i] = Literal(is_tuple ? out_shape.tuple_shapes(i) : out_shape);
  }

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      output_shape, [&](absl::Span<const int64_t> output_index) {
        return GenerateReduceOutputElement(
            is_tuple, output_index, init_values, input_args,
            absl::Span<Literal>(results), function, embedded_evaluator.get(),
            arg_dim_steps, arg_dim_counts, result_to_arg_index);
      }));

  if (is_tuple) {
    Literal tuple_result(inferred_return_shape);
    for (int64_t i = 0; i < num_args; ++i) {
      TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
    }
    evaluated_[reduce] = std::move(tuple_result);
  } else {
    CHECK_EQ(results.size(), 1);
    evaluated_[reduce] = std::move(results[0]);
  }
  if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) {
    TF_ASSIGN_OR_RETURN(evaluated_[reduce],
                        evaluated_[reduce].ConvertToShape(reduce->shape()));
  }
  return OkStatus();
}

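// Example (added commentary): a reduction this handler evaluates, in HLO
// text form. Summing an f32[10,20] operand over dimension 0 produces an
// f32[20] result; with a scalar-add to_apply computation it also qualifies
// for the fast-add path above:
//
//   ROOT out = f32[20] reduce(f32[10,20] x, f32[] zero), dimensions={0},
//              to_apply=add
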
Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
  // Here we delegate the handling to the typed visitor class, instantiated
  // with the type of the first input of ReduceWindow. The variadic case
  // inside the typed visitor is written without using the template
  // parameter, so it does not matter which type is used to instantiate it
  // here. We choose not to move the ReduceWindow implementation from the
  // typed visitor to here because we need to reuse the IterateThroughWindow
  // method, which is defined and only available inside the typed visitor.
  if (hlo->shape().IsTuple()) {
    return hlo->Visit(
        typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get());
  } else {
    return DefaultAction(hlo);
  }
}

Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
  if (!custom_call_handler_) {
    // No handler is registered; this means custom-calls are not allowed.
    return DefaultAction(custom_call);
  }

  // Evaluate input operands so the handler has access to the operand data.
  std::vector<const Literal*> operands;
  operands.reserve(custom_call->operand_count());
  for (const HloInstruction* operand : custom_call->operands()) {
    operands.push_back(&GetEvaluatedLiteralFor(operand));
  }

  // Synchronously issue the handler to populate the instruction output
  // literal.
  TF_ASSIGN_OR_RETURN(
      auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));

  evaluated_[custom_call] = std::move(output);
  return OkStatus();
}

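// Usage sketch (added commentary; assumes the handler setter declared in
// hlo_evaluator.h): register a callback that produces the output literal for
// every custom-call encountered during evaluation.
//
//   HloEvaluator evaluator;
//   evaluator.set_custom_call_handler(
//       [](HloInstruction* custom_call,
//          absl::Span<const Literal*> operands) -> StatusOr<Literal> {
//         // Dispatch on custom_call->custom_call_target() and build a
//         // result literal of shape custom_call->shape().
//         return Literal::CreateFromShape(custom_call->shape());
//       });
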
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
  VLOG(3) << "About to visit HLO: " << hlo->ToString();
  if (!enable_partial_evaluation_) {
    for (HloInstruction* operand : hlo->mutable_operands()) {
      if (!IsAlreadyEvaluated(operand) ||
          !GetEvaluatedLiteralFor(operand).IsKnown()) {
        return tensorflow::errors::FailedPrecondition(
            "Failed to evaluate instruction since its operands are unknown "
            "or undetermined and partial evaluation is not enabled.");
      }
    }
  }
  return ShapeUtil::ValidateShape(hlo->shape());
}

Status HloEvaluator::Postprocess(HloInstruction* hlo) {
  VLOG(3) << "Finished visiting " << hlo->ToString()
          << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
  // For convenience, the evaluated literal may have been produced in a
  // different layout; relayout it as indicated by the HLO instruction.
  auto evaluated_shape = GetEvaluatedLiteralFor(hlo).shape();
  xla::Shape hlo_shape = hlo->shape();
  if (hlo_shape.IsArray() && !hlo_shape.has_layout()) {
    *hlo_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(hlo_shape);
  }
  if (evaluated_shape.has_layout() && hlo_shape.has_layout() &&
      !Layout::Equal().MinorToMajorOnly()(evaluated_shape.layout(),
                                          hlo_shape.layout())) {
    evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo_shape);
  }
  return OkStatus();
}

namespace {
template <typename T>
std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
    const Array2D<T>& lhs, const Array2D<T>& rhs,
    const std::function<void(const void* run_options_ptr, T* out, T* lhs,
                             T* rhs, int64_t m, int64_t n, int64_t k,
                             int32_t transpose_lhs, int32_t transpose_rhs)>&
        impl_fn) {
  CHECK_EQ(lhs.width(), rhs.height());
  int m = lhs.height();
  int n = rhs.width();
  int k = lhs.width();
  auto result = std::make_unique<Array2D<T>>(m, n);
  // Because Eigen is a header-oriented library, make sure that the Eigen code
  // is the same as the code used by the CPU backend (otherwise the linker
  // will randomly pick *some* definition).
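  // Note (added commentary): the operands are deliberately swapped below.
  // The runtime kernel multiplies column-major matrices, so feeding it
  // (rhs, lhs) with dimensions (n, m) computes (lhs * rhs)^T in column-major
  // order, which occupies memory identically to lhs * rhs in Array2D's
  // row-major order.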
  impl_fn(
      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n,
      m, k,
      /*transpose_lhs=*/0,
      /*transpose_rhs=*/0);
  return result;
}
}  // namespace

std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
    const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
  return MatmulArray2DImpl<Eigen::half>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
}

std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
    const Array2D<float>& lhs, const Array2D<float>& rhs) {
  return MatmulArray2DImpl<float>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
}

std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
    const Array2D<double>& lhs, const Array2D<double>& rhs) {
  return MatmulArray2DImpl<double>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
}

std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
    const Array2D<std::complex<float>>& lhs,
    const Array2D<std::complex<float>>& rhs) {
  return MatmulArray2DImpl<std::complex<float>>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
}

std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
    const Array2D<std::complex<double>>& lhs,
    const Array2D<std::complex<double>>& rhs) {
  return MatmulArray2DImpl<std::complex<double>>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
}

std::unique_ptr<Array2D<int32_t>> HloEvaluator::MatmulArray2D(
    const Array2D<int32_t>& lhs, const Array2D<int32_t>& rhs) {
  return MatmulArray2DImpl<int32_t>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulS32);
}

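// Usage sketch (added commentary; assumes the static declarations in
// hlo_evaluator.h): multiply two small row-major arrays through the
// single-threaded Eigen kernels.
//
//   Array2D<float> lhs({{1.0f, 2.0f}, {3.0f, 4.0f}});  // 2x2
//   Array2D<float> rhs({{5.0f}, {6.0f}});              // 2x1
//   std::unique_ptr<Array2D<float>> out =
//       HloEvaluator::MatmulArray2D(lhs, rhs);  // 2x1 result: {{17}, {39}}
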
}  // namespace xla