1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16
17 #include <algorithm>
18 #include <cmath>
19 #include <complex>
20 #include <cstdint>
21 #include <cstdlib>
22 #include <functional>
23 #include <iterator>
24 #include <memory>
25 #include <optional>
26 #include <string>
27 #include <type_traits>
28 #include <utility>
29 #include <vector>
30
31 #include "absl/algorithm/container.h"
32 #include "absl/base/internal/endian.h"
33 #include "absl/cleanup/cleanup.h"
34 #include "absl/container/inlined_vector.h"
35 #include "absl/strings/match.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "tensorflow/compiler/xla/index_util.h"
39 #include "tensorflow/compiler/xla/layout_util.h"
40 #include "tensorflow/compiler/xla/literal.h"
41 #include "tensorflow/compiler/xla/literal_util.h"
42 #include "tensorflow/compiler/xla/map_util.h"
43 #include "tensorflow/compiler/xla/primitive_util.h"
44 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
45 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
46 #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
47 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_query.h"
50 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
51 #include "tensorflow/compiler/xla/service/shape_inference.h"
52 #include "tensorflow/compiler/xla/shape_util.h"
53 #include "tensorflow/compiler/xla/status_macros.h"
54 #include "tensorflow/compiler/xla/statusor.h"
55 #include "tensorflow/compiler/xla/types.h"
56 #include "tensorflow/compiler/xla/util.h"
57 #include "tensorflow/compiler/xla/window_util.h"
58 #include "tensorflow/core/lib/core/bitmap.h"
59 #include "tensorflow/core/lib/core/errors.h"
60 #include "tensorflow/core/lib/core/status.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/logging.h"
63 #include "tensorflow/core/platform/protobuf.h"
64 #include "tensorflow/core/platform/status.h"
65 #include "tensorflow/core/platform/statusor.h"
66 #include "tensorflow/core/platform/types.h"
67 #include "tensorflow/stream_executor/lib/statusor.h"
68
69 namespace xla {
70
71 namespace {
72
73 template <typename OperandT>
74 StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
75 LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
76 std::function<bool(OperandT, OperandT)> compare_op;
77 switch (direction) {
78 case ComparisonDirection::kEq:
79 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
80 return lhs_el == rhs_el;
81 };
82 break;
83 case ComparisonDirection::kNe:
84 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
85 return lhs_el != rhs_el;
86 };
87 break;
88 case ComparisonDirection::kGe:
89 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
90 return lhs_el >= rhs_el;
91 };
92 break;
93 case ComparisonDirection::kGt:
94 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
95 return lhs_el > rhs_el;
96 };
97 break;
98 case ComparisonDirection::kLe:
99 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
100 return lhs_el <= rhs_el;
101 };
102 break;
103 case ComparisonDirection::kLt:
104 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
105 return lhs_el < rhs_el;
106 };
107 break;
108 }
109
110 Literal result(shape);
111 TF_RETURN_IF_ERROR(
112 result.Populate<bool>([&](absl::Span<const int64_t> multi_index) {
113 return compare_op(lhs_literal.Get<OperandT>(multi_index),
114 rhs_literal.Get<OperandT>(multi_index));
115 }));
116
117 return std::move(result);
118 }
119
120 template <>
121 StatusOr<Literal> Compare<complex64>(const Shape& shape,
122 ComparisonDirection direction,
123 LiteralSlice lhs_literal,
124 LiteralSlice rhs_literal) {
125 std::function<bool(complex64, complex64)> compare_op;
126 switch (direction) {
127 case ComparisonDirection::kEq:
128 compare_op = [](complex64 lhs_el, complex64 rhs_el) {
129 return lhs_el == rhs_el;
130 };
131 break;
132 case ComparisonDirection::kNe:
133 compare_op = [](complex64 lhs_el, complex64 rhs_el) {
134 return lhs_el != rhs_el;
135 };
136 break;
137 default:
138 LOG(FATAL) << "unhandled direction for conversion to Comparison: "
139 << ComparisonDirectionToString(direction);
140 }
141
142 Literal result(shape);
143 TF_RETURN_IF_ERROR(
144 result.Populate<bool>([&](absl::Span<const int64_t> multi_index) {
145 return compare_op(lhs_literal.Get<complex64>(multi_index),
146 rhs_literal.Get<complex64>(multi_index));
147 }));
148
149 return std::move(result);
150 }
151
152 template <>
153 StatusOr<Literal> Compare<complex128>(const Shape& shape,
154 ComparisonDirection direction,
155 LiteralSlice lhs_literal,
156 LiteralSlice rhs_literal) {
157 std::function<bool(complex128, complex128)> compare_op;
158 switch (direction) {
159 case ComparisonDirection::kEq:
160 compare_op = [](complex128 lhs_el, complex128 rhs_el) {
161 return lhs_el == rhs_el;
162 };
163 break;
164 case ComparisonDirection::kNe:
165 compare_op = [](complex128 lhs_el, complex128 rhs_el) {
166 return lhs_el != rhs_el;
167 };
168 break;
169 default:
170 LOG(FATAL) << "unhandled direction for conversion to Comparison: "
171 << ComparisonDirectionToString(direction);
172 }
173
174 Literal result(shape);
175 TF_RETURN_IF_ERROR(
176 result.Populate<bool>([&](absl::Span<const int64_t> multi_index) {
177 return compare_op(lhs_literal.Get<complex128>(multi_index),
178 rhs_literal.Get<complex128>(multi_index));
179 }));
180
181 return std::move(result);
182 }
183
184 // Represents an index into the while argument tuple and / or a value.
185 // At least one of param_index and value has a value; both of them could have
186 // a value.
187 struct ParamIndexAndValue {
188 std::optional<int64_t> param_index;
189 std::optional<int64_t> value;
190
191 bool IsValid() const { return param_index.has_value() || value.has_value(); }
192 };
193
194 // Represents the while loop condition comparison.
195 // We assume comparison is of the form: lhs comp rhs.
196 struct WhileCondComparison {
197 ComparisonDirection comparson_direction;
198 ParamIndexAndValue lhs;
199 ParamIndexAndValue rhs;
200 };
201
202 // Represents the parsed while loop condition. The loop induction variable may
203 // either be used in a comparison or returned directly, i.e., NoOp. In the case
204 // of NoOp, it contains the parameter index and initial value of the loop
205 // induction variable.
206 using WhileCondComparisonOrNoOp =
207 std::variant<WhileCondComparison, ParamIndexAndValue>;
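// Illustrative example (hypothetical HLO, not from this file): a loop
// condition whose root is
//   compare(get-tuple-element(param0), index=1, constant(10)), direction=LT
// would be parsed into
//   WhileCondComparison{kLt, /*lhs=*/{.param_index = 1}, /*rhs=*/{.value = 10}},
// whereas a condition that simply returns a PRED tuple element is represented
// by a bare ParamIndexAndValue holding that element's index.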
208
209 // Finds the while loop condition comparison by matching the loop condition root
210 // with known patterns.
211 std::optional<WhileCondComparisonOrNoOp> PatternMatchLoopCondComparison(
212 HloInstruction* loop_cond_root) {
213 // Base pattern #1: gte-0 comp gte-1
214 if (Match(loop_cond_root,
215 match::Compare()
216 .WithOperand(0, match::GetTupleElement().WithOperand(
217 0, match::Parameter().WithParameterNum(0)))
218 .WithOperand(1,
219 match::GetTupleElement().WithOperand(
220 0, match::Parameter().WithParameterNum(0))))) {
221 return WhileCondComparison{
222 loop_cond_root->comparison_direction(),
223 {/*param_index=*/loop_cond_root->operand(0)->tuple_index()},
224 {/*param_index=*/loop_cond_root->operand(1)->tuple_index()}};
225 }
226 // Base pattern #2: constant comp gte
227 if (Match(loop_cond_root,
228 match::Compare()
229 .WithOperand(0, match::Constant())
230 .WithOperand(1,
231 match::GetTupleElement().WithOperand(
232 0, match::Parameter().WithParameterNum(0))))) {
233 std::optional<int64_t> lhs_value =
234 loop_cond_root->operand(0)->literal().GetFirstInteger();
235 if (!lhs_value.has_value()) {
236 return std::nullopt;
237 }
238 return WhileCondComparison{
239 loop_cond_root->comparison_direction(),
240 {/*param_index=*/std::nullopt, /*value=*/*lhs_value},
241 {/*param_index=*/loop_cond_root->operand(1)->tuple_index()}};
242 }
243 // Base pattern #3: gte comp constant
244 if (Match(loop_cond_root,
245 match::Compare()
246 .WithOperand(0, match::GetTupleElement().WithOperand(
247 0, match::Parameter().WithParameterNum(0)))
248 .WithOperand(1, match::Constant()))) {
249 std::optional<int64_t> rhs_value =
250 loop_cond_root->operand(1)->literal().GetFirstInteger();
251 if (!rhs_value.has_value()) {
252 return std::nullopt;
253 }
254 return WhileCondComparison{
255 loop_cond_root->comparison_direction(),
256 {/*param_index=*/loop_cond_root->operand(0)->tuple_index(),
257 /*value=*/std::nullopt},
258 {/*param_index=*/std::nullopt, /*value=*/*rhs_value},
259 };
260 }
261 // Base pattern #4: gte is a boolean scalar and is returned immediately.
262 if (Match(loop_cond_root, match::GetTupleElement().WithOperand(
263 0, match::Parameter().WithParameterNum(0)))) {
264 if (loop_cond_root->shape().element_type() != PrimitiveType::PRED ||
265 loop_cond_root->shape().rank() != 0) {
266 return std::nullopt;
267 }
268 return ParamIndexAndValue{{/*param_index=*/loop_cond_root->tuple_index()}};
269 }
270
271 // Recursive pattern #1:
272 // loop_cond_root is a GetTupleElement whose operand is a call with a single
273 // operand, which is the enclosing computation's single parameter.
274 // In this case, if the called computation's root is a tuple, we can recurse
275 // on that tuple's element as the new loop_cond_root.
276 if (Match(loop_cond_root,
277 match::GetTupleElement().WithOperand(
278 0, match::Call().WithNumOperands(1).WithOperand(
279 0, match::Parameter().WithParameterNum(0))))) {
280 HloInstruction* call_instruction = loop_cond_root->mutable_operand(0);
281 HloComputation* to_apply = call_instruction->to_apply();
282 HloInstruction* to_apply_root = to_apply->root_instruction();
283 if (Match(to_apply_root, match::Tuple())) {
284 return PatternMatchLoopCondComparison(
285 to_apply_root->mutable_operand(loop_cond_root->tuple_index()));
286 }
287 }
288 // Recursive pattern #2:
289 // loop_cond_root is a GetTupleElement whose operand is a tuple.
290 // We can recurse on the tuple's element as the new loop_cond_root.
291 if (Match(loop_cond_root,
292 match::GetTupleElement().WithOperand(0, match::Tuple()))) {
293 HloInstruction* new_cond_root =
294 loop_cond_root->mutable_operand(0)->mutable_operand(
295 loop_cond_root->tuple_index());
296 return PatternMatchLoopCondComparison(new_cond_root);
297 }
298 return std::nullopt;
299 }
300
301 // Tries to parse the loop body to find how the induction variable is updated
302 // using pattern matching.
303 std::optional<int64_t> PatternMatchInductionVarUpdate(
304 HloInstruction* loop_body_root, int64_t tuple_index) {
305 // Pattern #1: induc_var = induc_var + constant
306 if (Match(loop_body_root,
307 match::Tuple().WithOperand(
308 tuple_index,
309 match::Add()
310 .WithOperand(0, match::GetTupleElement()
311 .WithTupleIndex(tuple_index)
312 .WithOperand(0, match::Parameter()))
313 .WithOperand(1, match::Constant())))) {
314 std::optional<int64_t> step_size = loop_body_root->operand(tuple_index)
315 ->operand(1)
316 ->literal()
317 .GetFirstInteger();
318 if (!step_size.has_value()) {
319 return std::nullopt;
320 }
321 return *step_size;
322 }
323 // Pattern #2: induc_var = constant + induc_var
324 if (Match(
325 loop_body_root,
326 match::Tuple().WithOperand(
327 tuple_index,
328 match::Add()
329 .WithOperand(0, match::Constant())
330 .WithOperand(1, match::GetTupleElement()
331 .WithTupleIndex(tuple_index)
332 .WithOperand(0, match::Parameter()))))) {
333 std::optional<int64_t> step_size = loop_body_root->operand(tuple_index)
334 ->operand(0)
335 ->literal()
336 .GetFirstInteger();
337 if (!step_size.has_value()) {
338 return std::nullopt;
339 }
340 return *step_size;
341 }
342
343 // Pattern #3: induc_var = induc_var - constant
344 if (Match(loop_body_root,
345 match::Tuple().WithOperand(
346 tuple_index,
347 match::Subtract()
348 .WithOperand(0, match::GetTupleElement()
349 .WithTupleIndex(tuple_index)
350 .WithOperand(0, match::Parameter()))
351 .WithOperand(1, match::Constant())))) {
352 std::optional<int64_t> step_size = loop_body_root->operand(tuple_index)
353 ->operand(1)
354 ->literal()
355 .GetFirstInteger();
356 if (!step_size.has_value()) {
357 return std::nullopt;
358 }
359 return -*step_size;
360 }
361
362 // Pattern #4: the induc_var is directly returned from the loop body with
363 // no changes.
364 if (Match(loop_body_root,
365 match::Tuple().WithOperand(
366 tuple_index,
367 match::GetTupleElement()
368 .WithOperand(0, match::Parameter().WithParameterNum(0))
369 .WithTupleIndex(tuple_index)))) {
370 return 0;
371 }
372 return std::nullopt;
373 }
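// Illustrative example (hypothetical HLO): for a loop body whose root is
//   tuple(add(get-tuple-element(param0), index=0, constant(1)), ...)
// PatternMatchInductionVarUpdate(root, /*tuple_index=*/0) matches pattern #1
// and returns a step size of 1; the equivalent subtract pattern would return
// -1, and an unchanged induction variable (pattern #4) returns 0.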
374
375 std::optional<bool> PatternMatchLoopCondVarOverride(
376 HloInstruction* loop_body_root, int64_t tuple_index) {
377 if (Match(loop_body_root, match::Tuple()) &&
378 loop_body_root->operand_count() > tuple_index) {
379 HloInstruction* cond_var_override =
380 loop_body_root->mutable_operand(tuple_index);
381 HloEvaluator evaluator;
382 StatusOr<Literal> new_cond_var = evaluator.Evaluate(
383 cond_var_override, /*recursively_evaluate_nonconstant_operands=*/true);
384 if (new_cond_var.ok()) {
385 return new_cond_var->GetFirstElement<bool>();
386 }
387 }
388 return std::nullopt;
389 }
390
391 // Represents a value that might or might not be determined statically.
392 struct DynamicOrStaticValue {
393 std::optional<int64_t> static_value;
394 bool is_dynamic() const { return !static_value.has_value(); }
395 };
396
397 constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
398
399 // Use this class to represent the precise details of the error to enable
400 // special treatment.
401 enum class EvalErrorDetail : uint32_t {
402 // The evaluation result depends on dynamic values such as parameters and
403 // infeed. Therefore, the HLO's value cannot be statically evaluated.
404 kDynamicValueDependence = 0,
405 };
406
407 Status MakeEvalErrorDueToParamOrInfeed(const HloInstruction& eval_instruction) {
408 Status error = tensorflow::errors::FailedPrecondition(
409 "Failed to evaluate instruction (", eval_instruction.name(),
410 ") since it depends on infeed or parameters to its parent computation (",
411 eval_instruction.parent()->name(), ").");
412 std::string error_payload;
413 error_payload.resize(sizeof(EvalErrorDetail));
414 absl::little_endian::Store32(
415 const_cast<char*>(error_payload.data()),
416 static_cast<uint32_t>(EvalErrorDetail::kDynamicValueDependence));
417 error.SetPayload(kEvalErrorDetailUrl, error_payload);
418 return error;
419 }
420
421 std::optional<EvalErrorDetail> ParseEvalErrorDetail(const Status& error) {
422 auto error_detail = error.GetPayload(kEvalErrorDetailUrl);
423 if (!error_detail.has_value() || error_detail->empty()) {
424 return std::nullopt;
425 }
426 return static_cast<EvalErrorDetail>(
427 absl::little_endian::Load32(error_detail->Flatten().data()));
428 }
429
430 // A convenience wrapper to compute the while loop's argument's init value at
431 // the given tuple_index. If the init value depends on parameters to the
432 // while loop's parent computation or infeed, we consider the init value
433 // dynamic.
434 std::optional<DynamicOrStaticValue> EvaluateWhileLoopParamInitValue(
435 HloInstruction* param_instruction, int64_t tuple_index) {
436 if (param_instruction->opcode() != HloOpcode::kTuple) {
437 return std::nullopt;
438 }
439 HloInstruction* element_instruction =
440 param_instruction->mutable_operand(tuple_index);
441 HloEvaluator evaluator;
442 StatusOr<Literal> value = evaluator.Evaluate(
443 element_instruction, /*recursively_evaluate_nonconstant_operands=*/true);
444 if (value.ok()) {
445 if (element_instruction->shape().element_type() == PrimitiveType::PRED) {
446 return DynamicOrStaticValue{
447 static_cast<int64_t>(value->GetFirstElement<bool>())};
448 } else {
449 return DynamicOrStaticValue{value->GetFirstInteger()};
450 }
451 } else {
452 std::optional<EvalErrorDetail> eval_error_detail =
453 ParseEvalErrorDetail(value.status());
454 if (eval_error_detail.has_value() &&
455 *eval_error_detail == EvalErrorDetail::kDynamicValueDependence) {
456 return DynamicOrStaticValue{std::nullopt};
457 }
458 }
459 return std::nullopt;
460 }
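// Illustrative example: if the while loop's init operand is
//   tuple(constant(0), p), where p is a parameter of the parent computation,
// then tuple_index 0 evaluates to a static value of 0, while tuple_index 1
// produces DynamicOrStaticValue{std::nullopt} because its value depends on a
// parameter and is therefore considered dynamic.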
461
462 } // namespace
463
464 std::optional<ParsedWhileLoop> PatternMatchParseWhileLoop(
465 HloInstruction* while_op) {
466 HloComputation* while_cond = while_op->while_condition();
467 HloComputation* while_body = while_op->while_body();
468 HloInstruction* while_operand = while_op->mutable_operand(0);
469 // Try to parse the loop condition comparison.
470 std::optional<WhileCondComparisonOrNoOp> loop_comparison_or_noop =
471 PatternMatchLoopCondComparison(while_cond->root_instruction());
472 if (!loop_comparison_or_noop.has_value()) {
473 return std::nullopt;
474 }
475 if (loop_comparison_or_noop->index() == 1) {
476 ParamIndexAndValue& parameter_index_and_value =
477 std::get<ParamIndexAndValue>(*loop_comparison_or_noop);
478 CHECK(parameter_index_and_value.param_index.has_value());
479 int64_t loop_cond_var_index = *parameter_index_and_value.param_index;
480 std::optional<DynamicOrStaticValue> noop_value =
481 EvaluateWhileLoopParamInitValue(while_operand, loop_cond_var_index);
482
483 if (noop_value.has_value()) {
484 if (noop_value->is_dynamic()) {
485 return kParsedDynamicWhileLoop;
486 } else if (*noop_value->static_value == 0) {
487 return ParsedWhileLoop{
488 ParsedStaticWhileLoop{/*trip_count=*/0,
489 /*induction_var_index=*/loop_cond_var_index,
490 /*induction_var_init_value=*/0,
491 /*step_size=*/0,
492 /*loop_bound=*/0}};
493 }
494 std::optional<bool> updated_loop_cond_var =
495 PatternMatchLoopCondVarOverride(while_body->root_instruction(),
496 loop_cond_var_index);
497 if (updated_loop_cond_var.has_value()) {
498 if (!*updated_loop_cond_var) {
499 return ParsedWhileLoop{
500 ParsedStaticWhileLoop{/*trip_count=*/1,
501 /*induction_var_index=*/loop_cond_var_index,
502 /*induction_var_init_value=*/0,
503 /*step_size=*/1,
504 /*loop_bound=*/1}};
505 } else {
506 // This is an infinite loop and we set trip_count to -1.
507 return ParsedWhileLoop{
508 ParsedStaticWhileLoop{/*trip_count=*/-1,
509 /*induction_var_index=*/loop_cond_var_index,
510 /*induction_var_init_value=*/0,
511 /*step_size=*/0,
512 /*loop_bound=*/1}};
513 }
514 }
515 }
516 return std::nullopt;
517 }
518 CHECK_EQ(loop_comparison_or_noop->index(), 0);
519 WhileCondComparison loop_comparison =
520 std::get<WhileCondComparison>(*loop_comparison_or_noop);
521 CHECK(loop_comparison.lhs.IsValid() && loop_comparison.rhs.IsValid());
522
523 // If either side of the while loop condition comparison takes its init value
524 // from a parameter of the while loop's parent computation, the loop is dynamic.
525 if (while_operand->opcode() == HloOpcode::kParameter) {
526 if (loop_comparison.lhs.param_index.has_value() ||
527 loop_comparison.rhs.param_index.has_value()) {
528 return kParsedDynamicWhileLoop;
529 }
530 }
531
532 // We can't handle the case when the while loop argument is not a Tuple
533 // instruction.
534 if (while_operand->opcode() != HloOpcode::kTuple) {
535 return std::nullopt;
536 }
537
538 // If loop cond comparison LHS does not have a value defined inside the loop
539 // cond computation, try to evaluate its init value inside the while loop's
540 // parent computation.
541 if (!loop_comparison.lhs.value.has_value()) {
542 std::optional<DynamicOrStaticValue> lhs_init_value =
543 EvaluateWhileLoopParamInitValue(while_operand,
544 *loop_comparison.lhs.param_index);
545 if (lhs_init_value.has_value()) {
546 if (lhs_init_value->is_dynamic()) {
547 return kParsedDynamicWhileLoop;
548 } else {
549 loop_comparison.lhs.value = *(lhs_init_value->static_value);
550 }
551 } else {
552 return std::nullopt;
553 }
554 }
555
556 // If loop cond comparison RHS does not have a value defined inside the loop
557 // cond computation, try to evaluate its init value inside the while loop's
558 // parent computation.
559 if (!loop_comparison.rhs.value.has_value()) {
560 std::optional<DynamicOrStaticValue> rhs_init_value =
561 EvaluateWhileLoopParamInitValue(while_operand,
562 *loop_comparison.rhs.param_index);
563 if (rhs_init_value.has_value()) {
564 if (rhs_init_value->is_dynamic()) {
565 return kParsedDynamicWhileLoop;
566 } else {
567 loop_comparison.rhs.value = *(rhs_init_value->static_value);
568 }
569 } else {
570 return std::nullopt;
571 }
572 }
573
574 // At this point we have either successfully evaluated the init values of both
575 // LHS and RHS, or we have already returned (dynamic loop or failure).
576 CHECK(loop_comparison.lhs.value.has_value());
577 CHECK(loop_comparison.rhs.value.has_value());
578
579 if (loop_comparison.lhs.param_index.has_value()) {
580 VLOG(3) << __func__ << " lhs index: " << *loop_comparison.lhs.param_index;
581 }
582
583 VLOG(3) << __func__ << " lhs bound: " << *loop_comparison.lhs.value;
584
585 if (loop_comparison.rhs.param_index.has_value()) {
586 VLOG(3) << __func__ << " rhs index: " << *loop_comparison.rhs.param_index;
587 }
588
589 VLOG(3) << __func__ << " rhs bound: " << *loop_comparison.rhs.value;
590
591 // Check whether LHS is the loop induction var.
592 std::optional<int64_t> lhs_induction_var_update;
593 if (loop_comparison.lhs.param_index.has_value()) {
594 lhs_induction_var_update = PatternMatchInductionVarUpdate(
595 while_body->root_instruction(), *loop_comparison.lhs.param_index);
596 }
597
598 // Check whether RHS is the loop induction var.
599 std::optional<int64_t> rhs_induction_var_update;
600 if (loop_comparison.rhs.param_index.has_value()) {
601 rhs_induction_var_update = PatternMatchInductionVarUpdate(
602 while_body->root_instruction(), *loop_comparison.rhs.param_index);
603 }
604
605 // Lhs is the induction variable.
606 if (lhs_induction_var_update.has_value()) {
607 // We cannot handle the case when both LHS and RHS are updated inside
608 // the loop body.
609 if (rhs_induction_var_update.has_value() &&
610 *rhs_induction_var_update != 0) {
611 return std::nullopt;
612 }
613 if (*lhs_induction_var_update > 0 &&
614 (loop_comparison.comparson_direction == Comparison::Direction::kLt ||
615 loop_comparison.comparson_direction == Comparison::Direction::kLe)) {
616 int64_t trip_count =
617 (*loop_comparison.rhs.value - *loop_comparison.lhs.value - 1) /
618 *lhs_induction_var_update +
619 1;
620 // Additional logic to deal with Equal comparison.
621 if (loop_comparison.comparson_direction == Comparison::Direction::kLe &&
622 (*loop_comparison.rhs.value - *loop_comparison.lhs.value) %
623 *lhs_induction_var_update ==
624 0) {
625 trip_count += 1;
626 }
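// Worked example of the trip count formula above (hypothetical values): with
// init value 0, loop bound 10 and step 3 under kLt, trip_count =
// (10 - 0 - 1) / 3 + 1 = 4 (i = 0, 3, 6, 9). Under kLe with init 0, bound 10
// and step 5, the base formula gives (10 - 0 - 1) / 5 + 1 = 2 and the
// remainder (10 - 0) % 5 is 0, so one extra iteration is added, yielding 3
// (i = 0, 5, 10).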
627 return ParsedWhileLoop{ParsedStaticWhileLoop{
628 /*trip_count=*/trip_count,
629 /*induction_var_index=*/*loop_comparison.lhs.param_index,
630 /*induction_var_init_value=*/*loop_comparison.lhs.value,
631 /*step_size=*/*lhs_induction_var_update,
632 /*loop_bound=*/*loop_comparison.rhs.value}};
633 } else if (*lhs_induction_var_update < 0 &&
634 (loop_comparison.comparson_direction ==
635 Comparison::Direction::kGt ||
636 loop_comparison.comparson_direction ==
637 Comparison::Direction::kGe)) {
638 int trip_count =
639 (*loop_comparison.lhs.value - *loop_comparison.rhs.value - 1) /
640 *lhs_induction_var_update +
641 1;
642 if (loop_comparison.comparson_direction == Comparison::Direction::kGe &&
643 (*loop_comparison.lhs.value - *loop_comparison.rhs.value) %
644 *lhs_induction_var_update ==
645 0) {
646 trip_count += 1;
647 }
648 return ParsedWhileLoop{ParsedStaticWhileLoop{
649 /*trip_count=*/trip_count,
650 /*induction_var_index=*/*(loop_comparison.lhs.param_index),
651 /*induction_var_init_value=*/*(loop_comparison.lhs.value),
652 /*step_size=*/-*lhs_induction_var_update,
653 /*loop_bound=*/*(loop_comparison.rhs.value)}};
654 }
655 return std::nullopt;
656 }
657 // Rhs is the induction variable.
658 if (rhs_induction_var_update.has_value()) {
659 // We cannot handle the case when both LHS and RHS are updated inside
660 // the loop body.
661 if (lhs_induction_var_update.has_value() &&
662 *lhs_induction_var_update != 0) {
663 return std::nullopt;
664 }
665 if (*rhs_induction_var_update > 0 &&
666 (loop_comparison.comparson_direction == Comparison::Direction::kGt ||
667 loop_comparison.comparson_direction == Comparison::Direction::kGe)) {
668 int trip_count =
669 (*loop_comparison.lhs.value - *loop_comparison.rhs.value - 1) /
670 *rhs_induction_var_update +
671 1;
672 if (loop_comparison.comparson_direction == Comparison::Direction::kGe &&
673 (*loop_comparison.lhs.value - *loop_comparison.rhs.value) %
674 *rhs_induction_var_update ==
675 0) {
676 trip_count += 1;
677 }
678 return ParsedWhileLoop{ParsedStaticWhileLoop{
679 /*trip_count=*/trip_count,
680 /*induction_var_index=*/*(loop_comparison.rhs.param_index),
681 /*induction_var_init_value=*/*(loop_comparison.rhs.value),
682 /*step_size=*/*rhs_induction_var_update,
683 /*loop_bound=*/*(loop_comparison.lhs.value)}};
684 } else if (*rhs_induction_var_update < 0 &&
685 (loop_comparison.comparson_direction ==
686 Comparison::Direction::kLt ||
687 loop_comparison.comparson_direction ==
688 Comparison::Direction::kLe)) {
689 int trip_count =
690 (*loop_comparison.rhs.value - *loop_comparison.lhs.value - 1) /
691 *rhs_induction_var_update +
692 1;
693 if (loop_comparison.comparson_direction == Comparison::Direction::kLe &&
694 (*loop_comparison.rhs.value - *loop_comparison.lhs.value) %
695 *rhs_induction_var_update ==
696 0) {
697 trip_count += 1;
698 }
699 return ParsedWhileLoop{ParsedStaticWhileLoop{
700 /*trip_count=*/trip_count,
701 /*induction_var_index=*/*(loop_comparison.rhs.param_index),
702 /*induction_var_init_value=*/*(loop_comparison.rhs.value),
703 /*step_size=*/-*rhs_induction_var_update,
704 /*loop_bound=*/*(loop_comparison.lhs.value)}};
705 }
706 return std::nullopt;
707 }
708 return std::nullopt;
709 }
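// A minimal usage sketch (hypothetical HLO and variable names): for a while
// instruction whose condition compares tuple element 0 against 42 with kLt and
// whose body adds 2 to that element, starting from an init value of 0,
//   std::optional<ParsedWhileLoop> parsed = PatternMatchParseWhileLoop(while_op);
// is expected to yield a ParsedStaticWhileLoop with trip_count 21,
// induction_var_index 0, induction_var_init_value 0, step_size 2, and
// loop_bound 42.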
710
711 // Note that types unsupported by the typed visitor are not necessarily
712 // unsupported by the non-typed HloEvaluator (parent evaluator) in its
713 // type-agnostic handlers. For example, HandleGetTupleElement in the parent
714 // type-agnostic evaluator can accept the Tuple primitive type, whereas
715 // HloEvaluatorTypedVisitor cannot.
716 HloEvaluator::HloEvaluator(int64_t max_loop_iterations)
717 : max_loop_iterations_(max_loop_iterations) {
718 typed_visitors_[PRED] =
719 std::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
720 typed_visitors_[U8] =
721 std::make_unique<HloEvaluatorTypedVisitor<uint8_t>>(this);
722 typed_visitors_[U16] =
723 std::make_unique<HloEvaluatorTypedVisitor<uint16_t>>(this);
724 typed_visitors_[U32] =
725 std::make_unique<HloEvaluatorTypedVisitor<uint32_t>>(this);
726 typed_visitors_[U64] =
727 std::make_unique<HloEvaluatorTypedVisitor<uint64_t>>(this);
728 typed_visitors_[S8] =
729 std::make_unique<HloEvaluatorTypedVisitor<int8_t>>(this);
730 typed_visitors_[S16] =
731 std::make_unique<HloEvaluatorTypedVisitor<int16_t>>(this);
732 typed_visitors_[S32] =
733 std::make_unique<HloEvaluatorTypedVisitor<int32_t>>(this);
734 typed_visitors_[S64] =
735 std::make_unique<HloEvaluatorTypedVisitor<int64_t>>(this);
736 typed_visitors_[F16] =
737 std::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
738 typed_visitors_[F32] =
739 std::make_unique<HloEvaluatorTypedVisitor<float>>(this);
740 typed_visitors_[F64] =
741 std::make_unique<HloEvaluatorTypedVisitor<double>>(this);
742 typed_visitors_[C64] =
743 std::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
744 typed_visitors_[C128] =
745 std::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
746
747 // Most of the evaluator computations we use don't support BF16 (e.g.,
748 // std::ceil, std::tanh). To make evaluator work with BF16, we set all
749 // elementwise computations to be done in F32 and do BF16<->F32 conversion
750 // around the input and the output of the computations.
751 typed_visitors_[BF16] =
752 std::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
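// For example, a BF16 tanh is evaluated roughly as
// bfloat16(std::tanh(static_cast<float>(x))): the input is widened to F32, the
// elementwise computation runs in F32, and the result is narrowed back to BF16.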
753
754 typed_visitors_[TUPLE] =
755 std::make_unique<FunctionVisitor>([](HloInstruction*) {
756 return Unimplemented(
757 "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
758 });
759 typed_visitors_[OPAQUE_TYPE] =
760 std::make_unique<FunctionVisitor>([](HloInstruction*) {
761 return Unimplemented(
762 "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE.");
763 });
764 typed_visitors_[TOKEN] =
765 std::make_unique<FunctionVisitor>([](HloInstruction*) {
766 return Unimplemented(
767 "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
768 });
769 }
770
771 StatusOr<Literal> HloEvaluator::Evaluate(
772 const HloComputation& computation,
773 absl::Span<const Literal* const> arg_literals) {
774 CHECK(computation.parent() != nullptr);
775 XLA_VLOG_LINES(
776 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
777
778 if (arg_literals.size() != computation.num_parameters()) {
779 return InvalidArgument(
780 "Expected %d argument%s, but got %d.", computation.num_parameters(),
781 computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
782 }
783 for (int64_t i = 0; i < arg_literals.size(); ++i) {
784 const auto& computation_shape =
785 computation.parameter_instruction(i)->shape();
786 const auto& arg_shape = arg_literals[i]->shape();
787 if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
788 arg_shape)) {
789 return InvalidArgument(
790 "Shape mismatch at parameter %d. Computation expected %s, but arg "
791 "was %s.",
792 i, ShapeUtil::HumanStringWithLayout(computation_shape),
793 ShapeUtil::HumanStringWithLayout(arg_shape));
794 }
795 }
796
797 evaluated_.clear();
798 arg_literals_.clear();
799 for (const auto& literal_ptr : arg_literals) {
800 arg_literals_.push_back(&*literal_ptr);
801 }
802
803 // Re-seed RNG, either from the configuration's seed or a monotonic
804 // per-evaluator seed (which prevents two evaluators from returning the same
805 // random sequence).
806 if (computation.parent()->config().seed()) {
807 seed_ = computation.parent()->config().seed();
808 } else {
809 // Start global_seed at a (true) random value.
810 static std::atomic<uint64_t> global_seed{std::random_device()()};
811 seed_ = global_seed.fetch_add(1);
812 }
813 engine_.seed(seed_);
814
815 TF_RETURN_IF_ERROR(computation.Accept(this));
816 const Literal& result =
817 GetEvaluatedLiteralFor(computation.root_instruction());
818 if (VLOG_IS_ON(100)) {
819 for (const HloInstruction* instr : computation.instructions()) {
820 VLOG(100) << instr->name() << " = " << GetEvaluatedLiteralFor(instr);
821 }
822 }
823 if (!result.IsKnown()) {
824 return MakeEvalErrorDueToParamOrInfeed(*computation.root_instruction());
825 }
826 return result.Clone();
827 }
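// A minimal usage sketch (hypothetical computation and literals): given an
// HloComputation* `computation` with a single F32 scalar parameter,
//   HloEvaluator evaluator;
//   Literal arg = LiteralUtil::CreateR0<float>(2.0f);
//   StatusOr<Literal> result = evaluator.Evaluate(*computation, {&arg});
// The argument literals must match the parameter shapes (up to layout), as
// checked above.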
828
829 StatusOr<Literal> HloEvaluator::Evaluate(
830 HloInstruction* instruction,
831 bool recursively_evaluate_nonconstant_operands) {
832 arg_literals_.clear();
833 evaluated_.clear();
834 auto enable_partial_evaluation_cleanup =
835 absl::MakeCleanup([this] { enable_partial_evaluation_ = false; });
836 enable_partial_evaluation_ = recursively_evaluate_nonconstant_operands;
837 TF_RETURN_IF_ERROR(
838 EvaluateInternal(instruction, /*shape_index=*/{},
839 recursively_evaluate_nonconstant_operands));
840 const Literal& result = GetEvaluatedLiteralFor(instruction);
841 if (!result.IsKnown()) {
842 return MakeEvalErrorDueToParamOrInfeed(*instruction);
843 }
844 return result.Clone();
845 }
846
847 bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result,
848 bool recursively_evaluate_nonconstant_operands) {
849 CHECK(result != nullptr);
850 auto result_or =
851 Evaluate(instruction, recursively_evaluate_nonconstant_operands);
852 if (!result_or.ok()) {
853 VLOG(1) << "TryEvaluate failed:" << result_or.status();
854 return false;
855 }
856
857 *result = std::move(result_or).value();
858 return true;
859 }
860
861 StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
862 const HloInstruction* instruction,
863 const absl::flat_hash_map<const HloInstruction*, const Literal*>&
864 substitutions) {
865 std::vector<std::unique_ptr<HloInstruction>> owned_operands;
866 for (const HloInstruction* operand : instruction->operands()) {
867 auto it = substitutions.find(operand);
868 if (it == substitutions.end()) {
869 owned_operands.push_back(operand->Clone());
870 } else {
871 owned_operands.push_back(
872 HloInstruction::CreateConstant(it->second->Clone()));
873 }
874 }
875
876 std::vector<HloInstruction*> operands;
877 operands.reserve(owned_operands.size());
878 for (auto& operand : owned_operands) {
879 operands.push_back(operand.get());
880 }
881
882 std::unique_ptr<HloInstruction> cloned_instruction =
883 instruction->CloneWithNewOperands(instruction->shape(), operands);
884 auto result = Evaluate(cloned_instruction.get());
885
886 return result;
887 }
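// A minimal usage sketch (hypothetical instructions): if `add` is an
// HloInstruction* computing x + y, then
//   Literal two = LiteralUtil::CreateR0<float>(2.0f);
//   Literal three = LiteralUtil::CreateR0<float>(3.0f);
//   auto sum = evaluator.EvaluateWithSubstitutions(
//       add, {{add->operand(0), &two}, {add->operand(1), &three}});
// evaluates the instruction as if its operands were the given constants.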
888
889 StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
890 HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
891 std::unique_ptr<HloInstruction> lhs_instr =
892 HloInstruction::CreateConstant(lhs.Clone());
893 std::unique_ptr<HloInstruction> rhs_instr =
894 HloInstruction::CreateConstant(rhs.Clone());
895
896 std::unique_ptr<HloInstruction> cloned_instruction =
897 HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
898 rhs_instr.get());
899 auto result = Evaluate(cloned_instruction.get());
900
901 return result;
902 }
903
904 StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
905 HloOpcode opcode, const Literal& lhs, const Literal& rhs,
906 const Literal& ehs) {
907 std::unique_ptr<HloInstruction> lhs_instr =
908 HloInstruction::CreateConstant(lhs.Clone());
909 std::unique_ptr<HloInstruction> rhs_instr =
910 HloInstruction::CreateConstant(rhs.Clone());
911 std::unique_ptr<HloInstruction> ehs_instr =
912 HloInstruction::CreateConstant(ehs.Clone());
913 TF_ASSIGN_OR_RETURN(auto output_shape,
914 ShapeInference::InferTernaryOpShape(
915 opcode, lhs.shape(), rhs.shape(), ehs.shape()));
916 std::unique_ptr<HloInstruction> cloned_instruction =
917 HloInstruction::CreateTernary(output_shape, opcode, lhs_instr.get(),
918 rhs_instr.get(), ehs_instr.get());
919 return Evaluate(cloned_instruction.get());
920 }
921
922 StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
923 ComparisonDirection direction, const Literal& lhs, const Literal& rhs) {
924 std::unique_ptr<HloInstruction> lhs_instr =
925 HloInstruction::CreateConstant(lhs.Clone());
926 std::unique_ptr<HloInstruction> rhs_instr =
927 HloInstruction::CreateConstant(rhs.Clone());
928
929 std::unique_ptr<HloInstruction> cloned_instruction =
930 HloInstruction::CreateCompare(
931 ShapeUtil::ChangeElementType(lhs.shape(), PRED), lhs_instr.get(),
932 rhs_instr.get(), direction);
933 auto result = Evaluate(cloned_instruction.get());
934
935 return result;
936 }
937
938 StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
939 HloOpcode opcode, const Literal& operand) {
940 std::unique_ptr<HloInstruction> operand_instr =
941 HloInstruction::CreateConstant(operand.Clone());
942
943 TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferUnaryOpShape(
944 opcode, operand.shape()));
945 std::unique_ptr<HloInstruction> cloned_instruction =
946 HloInstruction::CreateUnary(inferred_shape, opcode, operand_instr.get());
947 auto result = Evaluate(cloned_instruction.get());
948
949 return result;
950 }
951
952 StatusOr<Literal> HloEvaluator::EvaluateDotOp(
953 const DotDimensionNumbers& dim_numbers,
954 const PrecisionConfig& precision_config, const Literal& lhs,
955 const Literal& rhs) {
956 std::unique_ptr<HloInstruction> lhs_instr =
957 HloInstruction::CreateConstant(lhs.Clone());
958 std::unique_ptr<HloInstruction> rhs_instr =
959 HloInstruction::CreateConstant(rhs.Clone());
960
961 TF_ASSIGN_OR_RETURN(
962 Shape dot_shape,
963 ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers,
964 /*preferred_element_type=*/std::nullopt));
965
966 std::unique_ptr<HloInstruction> cloned_instruction =
967 HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
968 dim_numbers, precision_config);
969 return Evaluate(cloned_instruction.get());
970 }
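// Illustrative example: for an lhs of shape f32[2,3] and an rhs of shape
// f32[3,4] with lhs_contracting_dimensions={1} and
// rhs_contracting_dimensions={0}, the inferred dot_shape is f32[2,4] and the
// result is the ordinary matrix product of the two literals.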
971
972 Status HloEvaluator::EvaluateInternal(
973 HloInstruction* instruction, const ShapeIndex& shape_index,
974 bool recursively_evaluate_nonconstant_operands) {
975 // Don't need to evaluate this instruction again if it has already been
976 // evaluated.
977 if (IsAlreadyEvaluated(instruction, shape_index)) {
978 return OkStatus();
979 }
980
981 if (!recursively_evaluate_nonconstant_operands) {
982 if (!hlo_query::AllOperandsAreConstants(*instruction)) {
983 return tensorflow::errors::FailedPrecondition(
984 "Not all operands are constants.");
985 }
986 } else {
987 if (instruction->opcode() == HloOpcode::kGetTupleElement) {
988 ShapeIndex new_shape_index = shape_index;
989 new_shape_index.push_front(instruction->tuple_index());
990 TF_RETURN_IF_ERROR(
991 EvaluateInternal(instruction->mutable_operand(0), new_shape_index,
992 /*recursively_evaluate_nonconstant_operands=*/true));
993 } else if (instruction->opcode() == HloOpcode::kTuple &&
994 !shape_index.empty()) {
995 ShapeIndex new_shape_index = shape_index;
996 int64_t tuple_index = new_shape_index.front();
997 new_shape_index.pop_front();
998 TF_RETURN_IF_ERROR(EvaluateInternal(
999 instruction->mutable_operand(tuple_index), new_shape_index,
1000 /*recursively_evaluate_nonconstant_operands=*/true));
1001 } else {
1002 for (HloInstruction* operand : instruction->operands()) {
1003 TF_RETURN_IF_ERROR(EvaluateInternal(
1004 operand, /*shape_index=*/{},
1005 /*recursively_evaluate_nonconstant_operands=*/true));
1006 // Except for the above and following cases, we do not support handling
1007 // unknown operands for other HLOs. So mark the result as unknown.
1008 if ((!GetEvaluatedLiteralFor(operand).IsKnown() &&
1009 instruction->opcode() != HloOpcode::kCopy &&
1010 instruction->opcode() != HloOpcode::kCopyStart &&
1011 instruction->opcode() != HloOpcode::kCopyDone &&
1012 instruction->opcode() != HloOpcode::kAsyncStart &&
1013 instruction->opcode() != HloOpcode::kAsyncUpdate &&
1014 instruction->opcode() != HloOpcode::kAsyncDone &&
1015 instruction->opcode() != HloOpcode::kWhile)) {
1016 evaluated_[instruction] =
1017 Literal::CreateFromShapeWithUnknownLeafArrays(
1018 instruction->shape());
1019 return OkStatus();
1020 }
1021 }
1022 }
1023 }
1024 visitor_shape_index_ = shape_index;
1025 TF_RETURN_IF_ERROR(Preprocess(instruction));
1026 TF_RETURN_IF_ERROR(instruction->Visit(this));
1027 TF_RETURN_IF_ERROR(Postprocess(instruction));
1028 return OkStatus();
1029 }
1030
1031 Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
1032 const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
1033 Literal result(bitcast->shape());
1034 // Bitcast output is allowed to be smaller than the input if the backend-
1035 // specific buffer sizes for the input and output are the same. Since the HLO
1036 // evaluator doesn't have access to the backend-specific shape size function,
1037 // assume it's OK to bitcast if output <= input.
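// For example, bitcasting f32[4] to s32[4] (both 16 bytes) simply reinterprets
// the bytes; a bitcast whose output would be larger than its input fails the
// check below.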
1038 TF_RET_CHECK(operand_literal.size_bytes() >= result.size_bytes());
1039 memcpy(result.untyped_data(), operand_literal.untyped_data(),
1040 result.size_bytes());
1041 evaluated_[bitcast] = std::move(result);
1042 return OkStatus();
1043 }
1044
1045 Status HloEvaluator::HandleGetDimensionSize(
1046 HloInstruction* get_dimension_size) {
1047 HloInstruction* operand = get_dimension_size->mutable_operand(0);
1048 int64_t dim = get_dimension_size->dimension();
1049 if (dynamic_dimension_inference_ == nullptr) {
1050 return InvalidArgument(
1051 "Evaluator cannot evaluate get_dimension_size without "
1052 "set_dynamic_dimension_inference.");
1053 }
1054 HloInstruction* dynamic_size =
1055 dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
1056 if (dynamic_size != nullptr) {
1057 evaluated_[get_dimension_size] =
1058 GetEvaluatedLiteralFor(dynamic_size).Clone();
1059 return OkStatus();
1060 }
1061
1062 const Shape& shape = get_dimension_size->operand(0)->shape();
1063 Literal output(ShapeUtil::MakeShape(S32, {}));
1064 output.PopulateWithValue(
1065 static_cast<int32_t>(shape.dimensions(get_dimension_size->dimension())));
1066 evaluated_[get_dimension_size] = std::move(output);
1067 return OkStatus();
1068 }
1069
1070 Status HloEvaluator::HandleSetDimensionSize(
1071 HloInstruction* set_dimension_size) {
1072 const Literal& operand_literal =
1073 GetEvaluatedLiteralFor(set_dimension_size->operand(0));
1074 Literal result(set_dimension_size->shape());
1075 memcpy(result.untyped_data(), operand_literal.untyped_data(),
1076 operand_literal.size_bytes());
1077 const Literal& size_literal =
1078 GetEvaluatedLiteralFor(set_dimension_size->operand(1));
1079 result.SetDynamicSize(set_dimension_size->dimension(),
1080 size_literal.Get<int32_t>({}));
1081 evaluated_[set_dimension_size] = std::move(result);
1082 return OkStatus();
1083 }
1084
1085 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
1086 if (arg_literals_.empty()) {
1087 if (!enable_partial_evaluation_) {
1088 return tensorflow::errors::FailedPrecondition(
1089 "Failed to evaluate instruction since its operands are unknown "
1090 "or undetermined and partial evaluation is not enabled.");
1091 }
1092 evaluated_[parameter] =
1093 Literal::CreateFromShapeWithUnknownLeafArrays(parameter->shape());
1094 return OkStatus();
1095 }
1096
1097 // Nothing to do other than sanity checks. Parameters' values are stored in
1098 // arg_literals_.
1099 CHECK_LT(parameter->parameter_number(), arg_literals_.size());
1100
1101 #ifndef NDEBUG
1102 const Literal* input_literal = arg_literals_[parameter->parameter_number()];
1103 VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
1104 DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
1105 input_literal->shape()))
1106 << "parameter shape is: "
1107 << ShapeUtil::HumanStringWithLayout(parameter->shape())
1108 << ", but input literal shape is: "
1109 << ShapeUtil::HumanStringWithLayout(input_literal->shape());
1110 #endif
1111
1112 return OkStatus();
1113 }
1114
1115 Status HloEvaluator::HandleInfeed(HloInstruction* infeed) {
1116 if (!enable_partial_evaluation_) {
1117 return tensorflow::errors::FailedPrecondition(
1118 "Failed to evaluate instruction since its operands are unknown "
1119 "or undetermined and partial evaluation is not enabled.");
1120 }
1121 evaluated_[infeed] =
1122 Literal::CreateFromShapeWithUnknownLeafArrays(infeed->shape());
1123 return OkStatus();
1124 }
1125
1126 Status HloEvaluator::HandleConstant(HloInstruction*) { return OkStatus(); }
1127
1128 Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
1129 TF_ASSIGN_OR_RETURN(evaluated_[reshape],
1130 GetEvaluatedLiteralFor(reshape->operand(0))
1131 .Reshape(reshape->shape().dimensions()));
1132 return OkStatus();
1133 }
1134
1135 Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
1136 evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
1137 .Transpose(transpose->dimensions());
1138 return OkStatus();
1139 }
1140
1141 Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
1142 absl::Span<HloInstruction* const> operands(concatenate->operands());
1143 // The result's concatenate dimension is going to be the sum of the
1144 // concatenate dimensions of all operands taking part in the operation.
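// For example, concatenating f32[2,3] and f32[2,4] along dimension 1 yields
// f32[2,7]; only the concat dimension is summed, the other dimensions must
// match.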
1145 const Shape& reference_shape = operands[0]->shape();
1146 CHECK(reference_shape.IsArray());
1147 const int64_t rank = reference_shape.rank();
1148 const int64_t concat_dim = concatenate->dimensions()[0];
1149 CHECK_GE(concat_dim, 0);
1150 CHECK_LT(concat_dim, rank);
1151
1152 DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
1153 reference_shape.dimensions().end());
1154
1155 for (int64_t i = 1; i < operands.size(); ++i) {
1156 const Shape& operand_shape = operands[i]->shape();
1157 CHECK(operand_shape.IsArray());
1158 // Accumulate the concat dimension from all tensors taking part in the
1159 // operation.
1160 concat_dimensions[concat_dim] +=
1161 ShapeUtil::GetDimension(operand_shape, concat_dim);
1162 }
1163
1164 auto result_literal = LiteralUtil::CreateFromDimensions(
1165 reference_shape.element_type(), concat_dimensions);
1166 DimensionVector source_indices(rank, 0);
1167 DimensionVector dest_indices(concat_dimensions.size(), 0);
1168
1169 for (auto operand : operands) {
1170 const Shape& operand_shape = operand->shape();
1171 TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
1172 GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
1173 operand_shape.dimensions()));
1174 dest_indices[concat_dim] +=
1175 ShapeUtil::GetDimension(operand_shape, concat_dim);
1176 }
1177
1178 evaluated_[concatenate] = std::move(result_literal);
1179 return OkStatus();
1180 }
1181
1182 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
1183 auto operand = is_finite->operand(0);
1184 auto elem_ty = operand->shape().element_type();
1185 switch (elem_ty) {
1186 case PRED:
1187 case TUPLE:
1188 case OPAQUE_TYPE:
1189 case TOKEN:
1190 case S8:
1191 case S16:
1192 case S32:
1193 case S64:
1194 case U8:
1195 case U16:
1196 case U32:
1197 case U64:
1198 case C64:
1199 case C128:
1200 // Explicitly enumerate all types in this switch so that when we add a new
1201 // type, we'll get a compile error here.
1202 case PRIMITIVE_TYPE_INVALID:
1203 case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
1204 case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
1205 return InvalidArgument(
1206 "expected element type in shape to be floating point, but "
1207 "got: %s",
1208 PrimitiveType_Name(elem_ty));
1209
1210 case F16: {
1211 auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
1212 is_finite,
1213 [](Eigen::half elem_operand) {
1214 return std::isfinite(static_cast<float>(elem_operand));
1215 },
1216 GetEvaluatedLiteralFor(operand));
1217 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1218 break;
1219 }
1220 case BF16: {
1221 auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
1222 is_finite,
1223 [](bfloat16 elem_operand) {
1224 return std::isfinite(static_cast<float>(elem_operand));
1225 },
1226 GetEvaluatedLiteralFor(operand));
1227 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1228 break;
1229 }
1230 case F32: {
1231 auto result_or = ElementWiseUnaryOpImpl<bool, float>(
1232 is_finite,
1233 [](float elem_operand) { return std::isfinite(elem_operand); },
1234 GetEvaluatedLiteralFor(operand));
1235 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1236 break;
1237 }
1238 case F64: {
1239 auto result_or = ElementWiseUnaryOpImpl<bool, double>(
1240 is_finite,
1241 [](double elem_operand) { return std::isfinite(elem_operand); },
1242 GetEvaluatedLiteralFor(operand));
1243 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
1244 break;
1245 }
1246 }
1247
1248 return OkStatus();
1249 }
1250
1251 Status HloEvaluator::HandleReal(HloInstruction* real) {
1252 auto operand = real->operand(0);
1253 switch (operand->shape().element_type()) {
1254 case BF16: {
1255 auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
1256 real, [](bfloat16 elem_operand) { return elem_operand; },
1257 GetEvaluatedLiteralFor(operand));
1258 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1259 break;
1260 }
1261 case C64: {
1262 auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
1263 real, [](complex64 elem_operand) { return std::real(elem_operand); },
1264 GetEvaluatedLiteralFor(operand));
1265 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1266 break;
1267 }
1268 case C128: {
1269 auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
1270 real, [](complex128 elem_operand) { return std::real(elem_operand); },
1271 GetEvaluatedLiteralFor(operand));
1272 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1273 break;
1274 }
1275 case F16: {
1276 auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
1277 real, [](Eigen::half elem_operand) { return elem_operand; },
1278 GetEvaluatedLiteralFor(operand));
1279 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1280 break;
1281 }
1282 case F32: {
1283 auto result_or = ElementWiseUnaryOpImpl<float, float>(
1284 real, [](float elem_operand) { return elem_operand; },
1285 GetEvaluatedLiteralFor(operand));
1286 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1287 break;
1288 }
1289 case F64: {
1290 auto result_or = ElementWiseUnaryOpImpl<double, double>(
1291 real, [](double elem_operand) { return elem_operand; },
1292 GetEvaluatedLiteralFor(operand));
1293 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
1294 break;
1295 }
1296 default:
1297 LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
1298 << PrimitiveType_Name(operand->shape().element_type());
1299 }
1300
1301 return OkStatus();
1302 }
1303
1304 Status HloEvaluator::HandleImag(HloInstruction* imag) {
1305 auto operand = imag->operand(0);
1306 switch (operand->shape().element_type()) {
1307 case BF16: {
1308 auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
1309 imag, [](bfloat16 elem_operand) { return bfloat16(0); },
1310 GetEvaluatedLiteralFor(imag->operand(0)));
1311
1312 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1313 break;
1314 }
1315 case C64: {
1316 auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
1317 imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
1318 GetEvaluatedLiteralFor(imag->operand(0)));
1319
1320 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1321 break;
1322 }
1323 case C128: {
1324 auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
1325 imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
1326 GetEvaluatedLiteralFor(imag->operand(0)));
1327
1328 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1329 break;
1330 }
1331 case F16: {
1332 auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
1333 imag, [](Eigen::half elem_operand) { return Eigen::half(0); },
1334 GetEvaluatedLiteralFor(imag->operand(0)));
1335
1336 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1337 break;
1338 }
1339 case F32: {
1340 auto result_or = ElementWiseUnaryOpImpl<float, float>(
1341 imag, [](float elem_operand) { return 0; },
1342 GetEvaluatedLiteralFor(imag->operand(0)));
1343
1344 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1345 break;
1346 }
1347 case F64: {
1348 auto result_or = ElementWiseUnaryOpImpl<double, double>(
1349 imag, [](double elem_operand) { return 0; },
1350 GetEvaluatedLiteralFor(imag->operand(0)));
1351
1352 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
1353 break;
1354 }
1355 default:
1356 LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
1357 << PrimitiveType_Name(operand->shape().element_type());
1358 }
1359
1360 return OkStatus();
1361 }
1362
1363 Status HloEvaluator::HandleComplex(HloInstruction* complex) {
1364 const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
1365 const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
1366 TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
1367
1368 Literal result(complex->shape());
1369 switch (complex->shape().element_type()) {
1370 case C64: {
1371 TF_RETURN_IF_ERROR(result.Populate<complex64>(
1372 [&](absl::Span<const int64_t> multi_index) {
1373 return std::complex<float>(real.Get<float>(multi_index),
1374 imag.Get<float>(multi_index));
1375 }));
1376 break;
1377 }
1378 case C128: {
1379 TF_RETURN_IF_ERROR(result.Populate<complex128>(
1380 [&](absl::Span<const int64_t> multi_index) {
1381 return std::complex<double>(real.Get<double>(multi_index),
1382 imag.Get<double>(multi_index));
1383 }));
1384 break;
1385 }
1386 default:
1387 LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
1388 << PrimitiveType_Name(complex->shape().element_type());
1389 }
1390
1391 evaluated_[complex] = std::move(result);
1392 return OkStatus();
1393 }
1394
1395 Status HloEvaluator::HandleCompare(HloInstruction* compare) {
1396 ComparisonDirection direction = compare->comparison_direction();
1397 auto lhs = compare->operand(0);
1398 auto rhs = compare->operand(1);
1399 DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
1400 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
1401
1402 TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
1403
1404 const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
1405 const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
1406
1407 // Note here we switch on the operand's type.
1408 switch (lhs->shape().element_type()) {
1409 case PRED: {
1410 TF_ASSIGN_OR_RETURN(
1411 evaluated_[compare],
1412 Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
1413 } break;
1414 case U8: {
1415 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1416 Compare<uint8_t>(compare->shape(), direction,
1417 lhs_literal, rhs_literal));
1418 } break;
1419 case U16: {
1420 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1421 Compare<uint16_t>(compare->shape(), direction,
1422 lhs_literal, rhs_literal));
1423 } break;
1424 case U32: {
1425 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1426 Compare<uint32_t>(compare->shape(), direction,
1427 lhs_literal, rhs_literal));
1428 } break;
1429 case U64: {
1430 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1431 Compare<uint64_t>(compare->shape(), direction,
1432 lhs_literal, rhs_literal));
1433 } break;
1434 case S8: {
1435 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1436 Compare<int8_t>(compare->shape(), direction,
1437 lhs_literal, rhs_literal));
1438 } break;
1439 case S16: {
1440 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1441 Compare<int16_t>(compare->shape(), direction,
1442 lhs_literal, rhs_literal));
1443 } break;
1444 case S32: {
1445 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1446 Compare<int32_t>(compare->shape(), direction,
1447 lhs_literal, rhs_literal));
1448 } break;
1449 case S64: {
1450 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1451 Compare<int64_t>(compare->shape(), direction,
1452 lhs_literal, rhs_literal));
1453 } break;
1454 case F16: {
1455 TF_ASSIGN_OR_RETURN(
1456 evaluated_[compare],
1457 Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
1458 } break;
1459 case BF16: {
1460 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1461 Compare<bfloat16>(compare->shape(), direction,
1462 lhs_literal, rhs_literal));
1463 } break;
1464 case F32: {
1465 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1466 Compare<float>(compare->shape(), direction,
1467 lhs_literal, rhs_literal));
1468 } break;
1469 case F64: {
1470 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1471 Compare<double>(compare->shape(), direction,
1472 lhs_literal, rhs_literal));
1473 } break;
1474 case C64: {
1475 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1476 Compare<complex64>(compare->shape(), direction,
1477 lhs_literal, rhs_literal));
1478 } break;
1479 case C128: {
1480 TF_ASSIGN_OR_RETURN(evaluated_[compare],
1481 Compare<complex128>(compare->shape(), direction,
1482 lhs_literal, rhs_literal));
1483 } break;
1484 default:
1485 LOG(FATAL) << "HandleCompare: unknown primitive type: "
1486 << PrimitiveType_Name(lhs->shape().element_type());
1487 }
1488
1489 return OkStatus();
1490 }
1491
1492 Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
1493 std::vector<const Literal*> operand_literals;
1494 std::vector<Literal> operand_literal_values;
1495 if (!visitor_shape_index_.empty()) {
1496 // We only need to evaluate tuple at visitor_shape_index_. The other
1497 // operands might not have been evaluated, so mark the other operands as
1498 // undetermined.
1499 int64_t tuple_index = visitor_shape_index_.front();
1500 operand_literal_values.resize(tuple->operand_count());
1501 for (int operand_index = 0; operand_index < tuple->operand_count();
1502 ++operand_index) {
1503 if (operand_index == tuple_index) {
1504 operand_literals.push_back(
1505 &GetEvaluatedLiteralFor(tuple->mutable_operand(operand_index)));
1506 } else {
1507 operand_literal_values[operand_index] =
1508 Literal::CreateFromShapeWithUndeterminedLeafArrays(
1509 ShapeUtil::GetSubshape(tuple->shape(), {operand_index}));
1510 operand_literals.push_back(&operand_literal_values[operand_index]);
1511 }
1512 }
1513 } else {
1514 for (auto operand : tuple->operands()) {
1515 operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
1516 }
1517 }
1518
1519 if (evaluated_.contains(tuple)) {
1520 Literal new_result = LiteralUtil::MakeTuple(operand_literals);
1521 CHECK(new_result.IsDetermined(visitor_shape_index_));
1522 TF_RETURN_IF_ERROR(
1523 evaluated_[tuple].CopyFrom(new_result,
1524 /*dest_shape_index=*/visitor_shape_index_,
1525 /*src_shape_index=*/visitor_shape_index_));
1526 } else {
1527 evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
1528 }
1529 return OkStatus();
1530 }
1531
1532 namespace {
1533
1534 // These helper templates convert between data types and are intended to be used only
1535 // within the DFT implementation below. The special case is IRFFT, where the
1536 // specialization drops imaginary parts of complex values and returns real
1537 // numbers.
1538 template <typename ToType, typename FromType>
1539 struct TypeConverter {
1540 static inline ToType GetAs(FromType value) {
1541 return static_cast<ToType>(value);
1542 }
1543 };
1544
1545 template <typename FromType>
1546 struct TypeConverter<float, FromType> {
1547 static inline float GetAs(FromType value) {
1548 return static_cast<float>(value.real());
1549 }
1550 };
1551
1552 // This class implements the discrete Fourier transform. All transform types
1553 // (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the arbitrary rank and
1554 // length of each dimension of the transform, and arbitrary layouts of the input
1555 // and output literals. The class template parameter must be a complex type, and
1556 // all internal calculations will be performed using this type.
1557 //
1558 // The input literal provides input data, which must be complex64 for FFT, IFFT,
1559 // IRFFT transforms and float for RFFT. The transform is computed over the
1560 // innermost dimensions of the input, thus the rank of the input data must be
1561 // same as fft_rank or larger. The input is expected to provide Ni values along
1562 // each transform axis with one exception: for IRFFT, only (N0 / 2) + 1 values
1563 // are needed along the X axis (the innermost index). To increase flexibility,
1564 // this implementation can handle mismatches between the input size and
1565 // transform lengths by either dropping extra input values or using zeroes in
1566 // place of missing input values as necessary. If the input data has rank higher
1567 // than the transform, the transform is applied for each valid combination of
1568 // the higher-ranking indices.
1569 //
1570 // The output contains complex64 values for FFT, IFFT, RFFT, and float values
1571 // for IRFFT. The rank of the output as well as the sizes of the dimensions
1572 // above the rank of the transform must match those of the input. Sizes of the
1573 // output's "fft_rank" innermost dimensions are expected to match the length of
1574 // the transform along respective axes with one exception: for RFFT, the output
1575 // is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the
1576 // length(s) mismatch, the FFT output is trimmed to fit into the provided output
1577 // shape, or the output is padded with zero values appropriately.
1578 //
1579 // For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17]
1580 // input array will perform two transforms over the [][15][17] data in the sub
1581 // arrays [0][][] and [1][][], dropping the values along axis X and padding axis
1582 // Y with zeroes to create 16x16 working sets, and generating
1583 // complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to
1584 // complex64[64][16][9] input array will use all input values and will produce
1585 // float[64][16][16] output.
1586 //
1587 // The 1D transform for lengths that are powers of 2 uses the Cooley-Tukey
1588 // radix-2 decimation-in-time algorithm. For all other 1D transform
1589 // lengths, a straightforward, but slow, loop nest is used. The transforms of
1590 // higher ranks apply sets of 1D transforms along each axis. For example, the 2D
1591 // transform is computed by applying 1D transforms to each column followed by
1592 // applying 1D transforms to each row.
1593 //
1594 // In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
1595 // time, where Ni is the length of the transform's i-th dimension. However, for
1596 // dimension lengths that are powers of 2, the corresponding term in the
1597 // summation is reduced to log(Ni), giving a runtime of
1598 // O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn))) in the best case.
1599 //
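// A minimal usage sketch (mirroring HandleFft further below; illustrative
// only, not an additional code path):
//
//   FftTransform<complex128> transform(fft);
//   Literal output_literal = Literal::CreateFromShape(fft->shape());
//   TF_RETURN_IF_ERROR(
//       transform.ComputeFft(fft, input_literal, &output_literal));
//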
1600 template <typename ComplexType>
1601 class FftTransform {
1602 public:
1603 explicit FftTransform(HloInstruction* fft)
1604 : fft_type_(fft->fft_type()),
1605 fft_rank_(fft->fft_length().size()),
1606 fft_lengths_(fft->fft_length()) {
1607 // Make fft_lengths_[0] the minormost dimension.
1608 absl::c_reverse(fft_lengths_);
1609 }
1610
1611 Status ComputeFft(HloInstruction* fft, const Literal& input_literal,
1612 Literal* output_literal) {
1613 const Shape& input_shape = input_literal.shape();
1614 const Shape& output_shape = fft->shape();
1615
1616 TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape));
1617
1618 const auto fft_strides = ComputeStrides(fft_lengths_);
1619
1620 // Working set size.
1621 const int64_t fft_size = fft_strides[fft_rank_];
1622
1623 if (fft_size > 0) {
1624 // Linearized working data set.
1625 std::vector<ComplexType> data(fft_size);
1626
1627 // Temporary buffer allocated once and used in 1D sweeps. For dimension
1628 // length values that are powers of 2, the buffer should be twice as large
1629 // to simultaneously hold input and output in Fft1D() below.
1630 int64_t buffer_size = 0;
1631 for (auto len : fft_lengths_) {
1632 int64_t size =
1633 absl::has_single_bit(static_cast<uint64_t>(len)) ? len * 2 : len;
1634 buffer_size = std::max(buffer_size, size);
1635 }
1636 std::vector<ComplexType> buffer(buffer_size);
1637
1638 // Sizes of each axis of input and output literals.
1639 const auto input_lengths = GetDimensionLengths(input_literal);
1640 const auto output_lengths = GetDimensionLengths(*output_literal);
1641
1642 // Strides for generating linearized indices into multidimensional arrays.
1643 const auto input_strides = ComputeStrides(input_lengths, input_literal);
1644 const auto output_strides =
1645 ComputeStrides(output_lengths, *output_literal);
1646
1647 // Visit all elements in the dimensions with ranks above the FFT rank. For
1648 // each such element invoke the transform. Use separate indices for the
1649 // input and the output to allow different layouts.
1650 auto base_case = [&](int64_t axis, int64_t output_index,
1651 int64_t input_index, bool within_src_bounds) {
1652 if (axis == fft_rank_ - 1) {
1653 // Base case: copy the data from the input literal, apply the
1654 // transform, and copy the result to the output literal.
1655 CHECK(within_src_bounds);
1656 bool input_is_zero = CopyDataFromInput(
1657 input_literal, input_index, fft_size, fft_lengths_, fft_strides,
1658 input_lengths, input_strides, absl::MakeSpan(data));
1659 if (!input_is_zero) {
1660 // Make 1D sweeps along each transform axis.
1661 Sweep(fft_lengths_, fft_strides, absl::MakeSpan(data),
1662 absl::MakeSpan(buffer));
1663 }
1664 CopyDataToOutput(absl::MakeSpan(data), output_index, fft_lengths_,
1665 fft_strides, output_lengths, output_strides,
1666 output_literal);
1667 return true;
1668 }
1669 return false;
1670 };
1671 GenerateIndices(output_lengths, output_strides, input_lengths,
1672 input_strides, input_shape.rank(), 0, 0, base_case);
1673 }
1674
1675 return OkStatus();
1676 }
1677
1678 private:
1679 // Common code used by 1D implementations, which copies data from the input to
1680 // the contiguous buffer. Returns true if all copied values are zero.
1681 static bool GatherToBuffer(absl::Span<ComplexType> data, int64_t length,
1682 int64_t start, int64_t stride, bool expand_input,
1683 absl::Span<ComplexType> buffer) {
1684 CHECK_GE(buffer.size(), length);
1685 bool input_is_zero = true;
1686 const int64_t ub = expand_input ? length / 2 + 1 : length;
1687 CHECK_GE(data.size(), start + (ub - 1) * stride);
1688 for (int64_t k = 0; k < ub; k++) {
1689 ComplexType value = data[start + k * stride];
1690 input_is_zero &= value == ComplexType(0.0, 0.0);
1691 buffer[k] = value;
1692 if (expand_input) {
1693 // Use conjugates of the values at indices [1 ... (ub - 2)] when the
1694 // length is even and at indices [1 ... (ub - 1)] when the length is odd
1695 // to calculate missing values at indices [ub ... (length - 1)].
1696 if (k > 0 && k < (length - ub + 1)) {
1697 buffer[length - k] = std::conj(value);
1698 }
1699 }
1700 }
1701 return input_is_zero;
1702 }
1703
1704 // Returns (conjugated, if 'inverse' is true) k-th twiddle for the given
1705 // length.
1706 static inline ComplexType Twiddle(int64_t k, int64_t length, bool inverse) {
1707 auto coeff = std::exp(ComplexType(0.0, -2.0 * M_PI * k / length));
1708 return inverse ? std::conj(coeff) : coeff;
1709 }
1710
1711 // Straightforward implementation of 1D DFT transform of arbitrary length.
1712 // Uses passed-in start index and stride to gather inputs from the data vector
1713 // into the preallocated buffer, computes the result, and writes it back to
1714 // the same locations in the data vector. Runs in O(length^2) time.
1715 //
1716 // Parameters contract_output and expand_input are used to avoid unnecessary
1717 // calculations. When contract_output is set to true, then only (length / 2) +
1718 // 1 output values are computed. When expand_input is set to true, then
1719 // (length / 2) + 1 values from the data set are used to re-create the full
1720 // set of size 'length', on which the transform is then performed.
1721 //
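// Concretely, for the forward transform the loop below computes
//
//   data[start + k * stride] =
//       sum_{n = 0}^{length - 1} buffer[n] * exp(-2 * pi * i * n * k / length)
//
// for each computed output index k, while the inverse transform conjugates the
// twiddle factors and divides each output value by 'length'.
//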
1722 static void NaiveDft1D(int64_t length, int64_t start, int64_t stride,
1723 bool inverse, bool contract_output, bool expand_input,
1724 absl::Span<ComplexType> data,
1725 absl::Span<ComplexType> buffer) {
1726 const bool input_is_zero =
1727 GatherToBuffer(data, length, start, stride, expand_input, buffer);
1728
1729 if (!input_is_zero) {
1730 const int64_t ub = contract_output ? length / 2 + 1 : length;
1731 for (int64_t k = 0; k < ub; k++) {
1732 ComplexType value = ComplexType(0.0, 0.0);
1733 for (int n = 0; n < length; n++) {
1734 value += buffer[n] * Twiddle(n * k, length, inverse);
1735 }
1736 data[start + k * stride] =
1737 inverse ? value / ComplexType(length, 0.0) : value;
1738 }
1739 }
1740 }
1741
1742 // Non-recursive implementation of the Cooley-Tukey radix-2 decimation in
1743 // time. Performs 1D FFT transform for the lengths, which are powers of 2.
1744 // Runs in O(length * log(length)) time. Uses the same parameters as the naive
1745 // implementation above, except that the preallocated buffer must be at least
1746 // twice as big as the length of the transform, because the buffer is used to
1747 // hold both input and output values for each stage of the transform.
1748 //
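// As a summary of the loop below (not a separate algorithm): each stage
// combines an (even, odd) pair of values into (even + w * odd, even - w * odd),
// where w is the twiddle factor of the pair's block, and after log2(length)
// such stages the buffer holds the complete transform.
//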
1749 static void Fft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
1750 bool contract_output, bool expand_input,
1751 absl::Span<ComplexType> data,
1752 absl::Span<ComplexType> buffer) {
1753 CHECK(absl::has_single_bit(static_cast<uint64_t>(length)));
1754 const bool input_is_zero =
1755 GatherToBuffer(data, length, start, stride, expand_input, buffer);
1756
1757 if (!input_is_zero) {
1758 auto generate_twiddles = [](int64_t length, bool inverse) {
1759 std::vector<ComplexType> twiddles;
1760 // Need only half the twiddles.
1761 for (int64_t k = 0; k < length / 2; k++) {
1762 twiddles.push_back(Twiddle(k, length, inverse));
1763 }
1764 return twiddles;
1765 };
1766
1767 // Indices into the parts of the buffer used for input and output values.
1768 int64_t in_base = length;
1769 int64_t out_base = 0;
1770
1771 // At each stage, we "split" the input data into num_blocks, with
1772 // block_size values in each block.
1773 for (int64_t num_blocks = 1; num_blocks < length; num_blocks *= 2) {
1774 // Swap input and output parts of the buffer.
1775 std::swap(in_base, out_base);
1776 auto twiddles = generate_twiddles(num_blocks * 2, inverse);
1777 const int64_t block_size = length / num_blocks;
1778 const int64_t next_iteration_block_size = block_size / 2;
1779 for (int64_t block = 0; block < num_blocks; block++) {
1780 const int64_t in_offset = in_base + block * block_size;
1781 const int64_t out_offset =
1782 out_base + block * next_iteration_block_size;
1783 // For each (even, odd) pair of values in the block, calculate two
1784 // output values as even + twiddle * odd and even - twiddle * odd.
1785 for (int64_t pair = 0; pair < block_size / 2; pair++) {
1786 const ComplexType even = buffer[in_offset + pair];
1787 const ComplexType odd = buffer[in_offset + block_size / 2 + pair];
1788 const ComplexType twiddled_odd = twiddles[block] * odd;
1789 buffer[out_offset + pair] = even + twiddled_odd;
1790 buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
1791 }
1792 }
1793 }
1794 // Copy computed result back to data.
1795 const int64_t ub = contract_output ? length / 2 + 1 : length;
1796 for (int64_t k = 0; k < ub; k++) {
1797 ComplexType value = buffer[out_base + k];
1798 data[start + k * stride] =
1799 inverse ? value / ComplexType(length, 0.0) : value;
1800 }
1801 }
1802 }
1803
1804 // Determine which implementation of the 1D transform to use and call it.
1805 static void Dft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
1806 bool contract_output, bool expand_input,
1807 absl::Span<ComplexType> data,
1808 absl::Span<ComplexType> buffer) {
1809 if (absl::has_single_bit(static_cast<uint64_t>(length))) {
1810 Fft1D(length, start, stride, inverse, contract_output, expand_input, data,
1811 buffer);
1812 } else {
1813 NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
1814 data, buffer);
1815 }
1816 }
1817
1818 // Helper to reverse the order of dimension lengths in the passed-in literal.
1819 static std::vector<int64_t> GetDimensionLengths(const Literal& literal) {
1820 auto dimensions = literal.shape().dimensions();
1821 return std::vector<int64_t>(dimensions.rbegin(), dimensions.rend());
1822 }
1823
1824 // Helper to compute strides for creating linear indices into multidimensional
1825 // data from the dimension lengths and the layout. Returns a new vector of
1826 // size lengths.size() + 1. The last element of the returned vector at index
1827 // [lengths.size()] contains the product of all dimension lengths.
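// For example, assuming the default major-to-minor layout, lengths {17, 15, 2}
// (minor-most dimension first) produce strides {1, 17, 255, 510}, where 510 is
// the total element count.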
1828 static std::vector<int64_t> ComputeStrides(
1829 const absl::Span<const int64_t> lengths, const Layout& layout) {
1830 const int64_t num_dimensions = lengths.size();
1831
1832 // Make sure that the layout length matches the number of dimensions.
1833 CHECK_EQ(num_dimensions, layout.minor_to_major_size());
1834
1835 // Calculate strides using layout-specified ordering of the dimensions and
1836 // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
1837 std::vector<int64_t> strides(num_dimensions + 1);
1838 int64_t stride = 1;
1839 for (int64_t i = 0; i < num_dimensions; i++) {
1840 // Reverse the ordering of the dimensions in the layout.
1841 const int64_t index = (num_dimensions - 1) - layout.minor_to_major(i);
1842 strides[index] = stride;
1843 stride *= lengths[index];
1844 }
1845 strides[num_dimensions] = stride;
1846
1847 return strides;
1848 }
1849
1850 // Compute strides as above using the default layout.
1851 static std::vector<int64_t> ComputeStrides(
1852 const absl::Span<const int64_t> lengths) {
1853 return ComputeStrides(lengths,
1854 LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
1855 }
1856
1857 // Compute strides as above using the layout from the literal, if available.
1858 static std::vector<int64_t> ComputeStrides(
1859 const absl::Span<const int64_t> lengths, const Literal& literal) {
1860 return literal.shape().has_layout()
1861 ? ComputeStrides(lengths, literal.shape().layout())
1862 : ComputeStrides(lengths);
1863 }
1864
1865 // Make 1D sweeps along each transform axis.
1866 void Sweep(const absl::Span<const int64_t> fft_lengths,
1867 const absl::Span<const int64_t> fft_strides,
1868 absl::Span<ComplexType> data, absl::Span<ComplexType> buffer) {
1869 const bool inverse =
1870 fft_type_ == FftType::IFFT || fft_type_ == FftType::IRFFT;
1871 const bool input_is_truncated = fft_type_ == FftType::IRFFT;
1872 const bool output_is_truncated = fft_type_ == FftType::RFFT;
1873
1874 // Recursively visit each column of the data along the sweep_axis. Calculate
1875 // linearized index of that column's first element and the stride, then
1876 // invoke 1D transform. For RFFT, avoid calculating unused output values:
1877 // first, compute only (length_x / 2) + 1 values along the X axis, then
1878 // limit the X coordinate to [0 ... (length / 2)] during the sweeps along
1879 // other axes. Similarly, for IRFFT sweep along higher dimensions first,
1880 // while keeping the X coordinate in the [0 ... (length / 2)] range, then
1881 // re-create negative frequencies omitted in the input and perform the
1882 // full-length transform along the X axis in the last sweep.
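// For example, a 2D RFFT of size 8x8 first computes (8 / 2) + 1 = 5 output
// values along X for each of the 8 rows, and then transforms only those 5
// columns along Y.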
1883 std::function<void(int64_t, int64_t, int64_t)> sweep =
1884 [&](int64_t sweep_axis, int64_t axis, int64_t start) {
1885 if (axis < 0) {
1886 // Base case: invoke 1D transform.
1887 const int64_t length = fft_lengths[sweep_axis];
1888 const int64_t stride = fft_strides[sweep_axis];
1889 const bool expand_input = input_is_truncated && sweep_axis == 0;
1890 const bool contract_output =
1891 output_is_truncated && sweep_axis == 0;
1892 Dft1D(length, start, stride, inverse, contract_output,
1893 expand_input, data, buffer);
1894 } else if (axis == sweep_axis) {
1895 // Visit only the elements with coordinate 0 along the sweep axis.
1896 sweep(sweep_axis, axis - 1, start);
1897 } else {
1898 const int64_t length = fft_lengths[axis];
1899 const bool is_truncated = input_is_truncated || output_is_truncated;
1900 const int64_t ub =
1901 is_truncated && axis == 0 ? (length / 2) + 1 : length;
1902 for (int64_t i = 0; i < ub; i++) {
1903 sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
1904 }
1905 }
1906 };
1907 if (input_is_truncated) {
1908 // Sweep along the X axis last for IRFFT.
1909 for (int64_t sweep_axis = fft_rank_ - 1; sweep_axis >= 0; sweep_axis--) {
1910 sweep(sweep_axis, fft_rank_ - 1, 0);
1911 }
1912 } else {
1913 // Sweep along the X axis first for RFFT. The order does not matter for
1914 // FFT and IFFT types; handle them here as well.
1915 for (int64_t sweep_axis = 0; sweep_axis < fft_rank_; sweep_axis++) {
1916 sweep(sweep_axis, fft_rank_ - 1, 0);
1917 }
1918 }
1919 }
1920
1921 // This template generates two linearized indices, which can be used to access
1922 // multidimensional arrays. It uses a recursive function, which passes the
1923 // indices to the user-supplied callback function. The destination index is
1924 // always within dst_lengths[] bounds. The boolean parameter within_src_bounds
1925 // indicates whether the source index is within src_lengths[] bounds.
1926 //
1927 // The value returned from the callback function controls the recursion depth.
1928 // Returning true indicates that the base case has been hit and the recursion
1929 // stops. Otherwise, the recursion proceeds along the next less-major axis.
1930 //
1931 // For example, the base case when the axis value becomes negative invokes the
1932 // callback function for each possible index within dst_lengths[] bounds. The
1933 // base case when the axis value is equal to zero limits the indices to point
1934 // only to first elements along the minor-most dimension, allowing the
1935 // callback function to handle all values along the X axis.
1936 //
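// For example, with dst_lengths = {2, 3} (minor-most first) and a callback
// whose base case triggers at axis == 0, that base case is reached
// dst_lengths[1] = 3 times, once per index along the next-to-minor axis, and
// each invocation handles the 2 values along the X axis itself.
//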
1937 template <typename BaseFn>
1938 static void GenerateIndices(const absl::Span<const int64_t> dst_lengths,
1939 const absl::Span<const int64_t> dst_strides,
1940 const absl::Span<const int64_t> src_lengths,
1941 const absl::Span<const int64_t> src_strides,
1942 int64_t rank, int64_t dst_start,
1943 int64_t src_start, BaseFn&& base) {
1944 CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
1945 CHECK_GE(dst_lengths.size(), rank);
1946 CHECK_EQ(src_lengths.size() + 1, src_strides.size());
1947 CHECK_GE(src_lengths.size(), rank);
1948
1949 std::function<void(int64_t, int64_t, int64_t, bool)> generate =
1950 [&](int64_t axis, int64_t dst_index, int64_t src_index,
1951 bool within_src_bounds) {
1952 if (!base(axis, dst_index, src_index, within_src_bounds)) {
1953 for (int64_t i = 0; i < dst_lengths[axis]; i++) {
1954 // Because the loop goes over dst_lengths[], the source index may
1955 // be out of src_lengths[] bounds. In this case, within_src_bounds
1956 // is false.
1957 within_src_bounds &= i < src_lengths[axis];
1958 generate(axis - 1, dst_index, src_index, within_src_bounds);
1959 dst_index += dst_strides[axis];
1960 src_index += src_strides[axis];
1961 }
1962 }
1963 };
1964 generate(rank - 1, dst_start, src_start, true);
1965 }
1966
1967 // Copies the input data from a literal to a pre-allocated vector. The sizes
1968 // of the input and the transform do not need to match. For each axis of the
1969 // transform, any extra input values beyond the transform length are ignored.
1970 // Conversely, if the input does not contain enough elements along any axis,
1971 // the data is padded with zeroes.
1972 //
1973 // For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
1974 // where length_x is the size of the full transform along the X axis.
1975 //
1976 // The input literal may have a rank higher than the rank of the transform.
1977 // Passed-in input_index value points to the first element of the input
1978 // literal to be copied.
1979 //
1980 // Returns true if all values in the work data set are zeroes.
1981 //
1982 template <typename InputType>
1983 bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
1984 int64_t fft_size,
1985 const absl::Span<const int64_t> fft_lengths,
1986 const absl::Span<const int64_t> fft_strides,
1987 const absl::Span<const int64_t> input_lengths,
1988 const absl::Span<const int64_t> input_strides,
1989 absl::Span<ComplexType> data) {
1990 CHECK_GE(data.size(), fft_size);
1991
1992 const bool input_is_truncated = fft_type_ == FftType::IRFFT;
1993
1994 // Recursively visit each transform dimension to copy input values to the
1995 // working data set. The base case handles inputs along the X axis.
1996 bool input_is_zero = true;
1997 const InputType* input_data = input_literal.data<InputType>().data();
1998 auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
1999 bool within_src_bounds) {
2000 if (axis == 0) {
2001 // For IRFFT, the negative frequencies are only needed for the sweep
2002 // along the X axis, which is performed last. Leave this part of the
2003 // working set uninitialized until then.
2004 const int64_t length = fft_lengths[axis];
2005 const int64_t ub = input_is_truncated ? (length / 2) + 1 : length;
2006 for (int64_t i = 0; i < ub; i++) {
2007 ComplexType value = ComplexType(0);
2008 // Read input value only if the index is within bounds.
2009 if (within_src_bounds && i < input_lengths[axis]) {
2010 value = TypeConverter<ComplexType, InputType>::GetAs(
2011 input_data[src_index + i * input_strides[axis]]);
2012 input_is_zero &= value == ComplexType(0.0, 0.0);
2013 }
2014 data[dst_index + i * fft_strides[axis]] = value;
2015 }
2016 return true;
2017 }
2018 return false;
2019 };
2020 GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
2021 fft_rank_, 0, input_start, base_case);
2022 return input_is_zero;
2023 }
2024
2025 // Copies the result of the transform to the literal output. The sizes of the
2026 // transform and output must match.
2027 //
2028 // For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
2029 // the size of the full transform along the X axis (the most minor dimension).
2030 //
2031 // The output literal may have a rank higher than the rank of the transform.
2032 // Passed-in output_index value points to the first element of the output
2033 // literal to be filled in.
2034 //
2035 template <typename OutputType>
2036 void CopyDataToOutput(const absl::Span<ComplexType> data,
2037 int64_t output_start,
2038 const absl::Span<const int64_t> fft_lengths,
2039 const absl::Span<const int64_t> fft_strides,
2040 const absl::Span<const int64_t> output_lengths,
2041 const absl::Span<const int64_t> output_strides,
2042 Literal* output_literal) {
2043 const bool output_is_truncated = fft_type_ == FftType::RFFT;
2044
2045 // Base case for recursive copy of the results to the output. The code
2046 // avoids making a recursive call for each output element by handling axis 0
2047 // in the loop (as opposed to making "axis < 0" the base case).
2048 OutputType* output_data = output_literal->data<OutputType>().data();
2049 auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
2050 bool within_src_bounds) {
2051 if (axis == 0) {
2052 // Drop negative frequencies for RFFT.
2053 const int64_t length = fft_lengths[axis];
2054 const int64_t ub = output_is_truncated ? (length / 2) + 1 : length;
2055 for (int64_t i = 0; i < output_lengths[axis]; i++) {
2056 OutputType value = OutputType(0);
2057 // Read data only if the index is within bounds.
2058 if (within_src_bounds && i < ub) {
2059 value = TypeConverter<OutputType, ComplexType>::GetAs(
2060 data[src_index + i * fft_strides[axis]]);
2061 }
2062 output_data[dst_index + i * output_strides[axis]] = value;
2063 }
2064 return true;
2065 }
2066 return false;
2067 };
2068 GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
2069 fft_rank_, output_start, 0, base_case);
2070 }
2071
2072 // Determine the type to use with the CopyDataFromInput<> template above.
2073 bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
2074 int64_t fft_size,
2075 const absl::Span<const int64_t> fft_lengths,
2076 const absl::Span<const int64_t> fft_strides,
2077 const absl::Span<const int64_t> input_lengths,
2078 const absl::Span<const int64_t> input_strides,
2079 absl::Span<ComplexType> data) {
2080 const bool input_is_float = fft_type_ == FftType::RFFT;
2081 if (input_is_float) {
2082 return CopyDataFromInput<float>(input_literal, input_start, fft_size,
2083 fft_lengths, fft_strides, input_lengths,
2084 input_strides, data);
2085 } else {
2086 return CopyDataFromInput<complex64>(input_literal, input_start, fft_size,
2087 fft_lengths, fft_strides,
2088 input_lengths, input_strides, data);
2089 }
2090 }
2091
2092 // Determine the type to use with the CopyDataToOutput<> template above.
2093 void CopyDataToOutput(const absl::Span<ComplexType> data,
2094 int64_t output_start,
2095 const absl::Span<const int64_t> fft_lengths,
2096 const absl::Span<const int64_t> fft_strides,
2097 const absl::Span<const int64_t> output_lengths,
2098 const absl::Span<const int64_t> output_strides,
2099 Literal* output_literal) {
2100 const bool output_is_float = fft_type_ == FftType::IRFFT;
2101 if (output_is_float) {
2102 CopyDataToOutput<float>(data, output_start, fft_lengths, fft_strides,
2103 output_lengths, output_strides, output_literal);
2104 } else {
2105 CopyDataToOutput<complex64>(data, output_start, fft_lengths, fft_strides,
2106 output_lengths, output_strides,
2107 output_literal);
2108 }
2109 }
2110
2111 Status CheckParameters(const Shape& input_shape, const Shape& output_shape) {
2112 // Check FFT parameters.
2113 if (fft_rank_ <= 0) {
2114 return InvalidArgument("Zero or negative FFT rank.");
2115 }
2116 if (*absl::c_min_element(fft_lengths_) < 0) {
2117 return InvalidArgument("Negative FFT length.");
2118 }
2119
2120 // Check input-related values.
2121 TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
2122 if (!input_shape.IsArray()) {
2123 return Unimplemented("Only array input shapes are supported.");
2124 }
2125 auto input_elt_type = input_shape.element_type();
2126 if (fft_type_ == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
2127 return InvalidArgument("Invalid input type: %d, must be %d (float).",
2128 input_elt_type, PrimitiveType::F32);
2129 }
2130 if (fft_type_ != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
2131 return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
2132 input_elt_type, PrimitiveType::C64);
2133 }
2134 const int64_t input_rank = input_shape.rank();
2135 if (input_rank < fft_rank_) {
2136 return InvalidArgument("Input shape rank is smaller than FFT rank.");
2137 }
2138
2139 // Check output-related values.
2140 TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
2141 if (!output_shape.IsArray()) {
2142 return Unimplemented("Only array output shapes are supported.");
2143 }
2144 auto output_elt_type = output_shape.element_type();
2145 if (fft_type_ == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
2146 return InvalidArgument("Invalid output type: %d, must be %d (float).",
2147 output_elt_type, PrimitiveType::F32);
2148 }
2149 if (fft_type_ != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
2150 return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
2151 output_elt_type, PrimitiveType::C64);
2152 }
2153 const int64_t output_rank = output_shape.rank();
2154 if (output_rank < fft_rank_) {
2155 return InvalidArgument("Output shape rank is smaller than FFT rank.");
2156 }
2157
2158 // Consistency of input and output parameters.
2159 if (input_rank != output_rank) {
2160 return InvalidArgument(
2161 "Ranks of input shape and output shape do not match.");
2162 }
2163 for (int64_t dim = 0; dim < input_rank - fft_rank_; dim++) {
2164 if (ShapeUtil::GetDimension(input_shape, dim) !=
2165 ShapeUtil::GetDimension(output_shape, dim)) {
2166 return InvalidArgument(
2167 "Higher dimension lengths of input shape and output shape do not "
2168 "match.");
2169 }
2170 }
2171
2172 return OkStatus();
2173 }
2174
2175 private:
2176 const FftType fft_type_;
2177 const int64_t fft_rank_;
2178 std::vector<int64_t> fft_lengths_;
2179 };
2180
2181 } // namespace
2182
2183 Status HloEvaluator::HandleFft(HloInstruction* fft) {
2184 const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
2185 Literal output_literal = Literal::CreateFromShape(fft->shape());
2186
2187 FftTransform<complex128> transform(fft);
2188 TF_RETURN_IF_ERROR(transform.ComputeFft(fft, input_literal, &output_literal));
2189 evaluated_[fft] = std::move(output_literal);
2190
2191 return OkStatus();
2192 }
2193
2194 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
2195 // dimensions while keeping the rest of the output dimensions clamped to 0.
2196 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
2197 const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
2198 int64_t output_rank = output_shape.dimensions_size();
2199 std::vector<int64_t> index_base(output_rank, 0);
2200 std::vector<int64_t> index_count;
2201 index_count.reserve(output_rank);
2202 for (int64_t i = 0; i < output_rank; i++) {
2203 bool is_output_batch_dim =
2204 !absl::c_binary_search(dim_numbers.offset_dims(), i);
2205 index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
2206 }
2207
2208 return {std::move(index_base), std::move(index_count),
2209 std::vector<int64_t>(output_rank, 1)};
2210 }
2211
2212 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
2213 // dimensions while keeping the rest of the output dimensions clamped to 0.
2214 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
2215 int64_t output_rank, absl::Span<const int64_t> slice_sizes,
2216 const GatherDimensionNumbers& dim_numbers) {
2217 std::vector<int64_t> index_base(output_rank, 0);
2218 std::vector<int64_t> index_count(output_rank, 1);
2219 int64_t slice_sizes_idx = 0;
2220 for (int64_t i = 0; i < output_rank; i++) {
2221 bool is_output_window_dim =
2222 absl::c_binary_search(dim_numbers.offset_dims(), i);
2223 if (is_output_window_dim) {
2224 while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
2225 slice_sizes_idx)) {
2226 slice_sizes_idx++;
2227 }
2228 index_count[i] = slice_sizes[slice_sizes_idx++];
2229 }
2230 }
2231
2232 return {std::move(index_base), std::move(index_count),
2233 std::vector<int64_t>(output_rank, 1)};
2234 }
2235
2236 // This functor computes the contribution of start_indices to an input index
2237 // corresponding to an output index. That is, given an output index I, it picks
2238 // out the batch indices in I and uses them to look up a starting index, G, from
2239 // the start indices tensor, and expands G into the input space according to
2240 // start_index_map.
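// For example (illustrative values only): with offset_dims = {1},
// index_vector_dim = 1, start_index_map = {0}, start_indices of shape
// s64[3, 1], and a rank-2 operand, an output index (i, j) has batch index {i};
// the functor reads G = start_indices[i, 0] and produces the input index
// contribution (G, 0).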
2241 class OutputBatchIndexToInputIndex {
2242 public:
2243 // The constructor does some setup work that is amortized across all
2244 // iterations.
2245 explicit OutputBatchIndexToInputIndex(
2246 const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
2247 const Shape& output_shape, const Literal* start_indices)
2248 : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
2249 for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
2250 output_dim_is_batch_dims_.push_back(
2251 !absl::c_binary_search(dim_numbers_.offset_dims(), i));
2252 }
2253
2254 for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
2255 int64_t index_of_input_dim_in_index_vector =
2256 std::distance(dim_numbers_.start_index_map().begin(),
2257 absl::c_find(dim_numbers_.start_index_map(), i));
2258 if (index_of_input_dim_in_index_vector ==
2259 dim_numbers_.start_index_map_size()) {
2260 input_dim_value_to_index_vector_.push_back(-1);
2261 } else {
2262 input_dim_value_to_index_vector_.push_back(
2263 index_of_input_dim_in_index_vector);
2264 }
2265 }
2266
2267 index_vector_index_.resize(start_indices_.shape().dimensions_size());
2268 input_index_.resize(input_shape.dimensions_size());
2269 int64_t index_vector_size =
2270 start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
2271 index_vector_.resize(index_vector_size);
2272 }
2273
2274 // Returns the contribution of start_indices to the input index corresponding
2275 // to output_index. See gather_inner_loop_body.
2276 //
2277 // This is conceptually a stateless transformation from output_index to the
2278 // gather input index, but:
2279 //
2280 // - Instead of allocating memory to represent the gather input index on
2281 // every invocation we reuse the same storage for the result
2282 // (input_index_), mutating it in place.
2283 // - Instead of allocating buffers for temporary values like
2284 // index_vector_index_ and index_vector on every invocation, we reuse the
2285 // same storage for all invocations.
2286 //
2287 // This returns a Span into memory owned by the class.
2288 StatusOr<absl::Span<const int64_t>> operator()(
2289 absl::Span<const int64_t> output_index) {
2290 PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
2291 TF_RETURN_IF_ERROR(FetchIndexVector());
2292 PropagateIndexVectorToInputIndex();
2293 return absl::Span<const int64_t>(input_index_);
2294 }
2295
2296 private:
2297 // Propagates the batch dimensions from the output index into
2298 // index_vector_index_ by mutating index_vector_index_ in place. Does not
2299 // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
2300 // we iterate over in FetchIndexVector.
2301 void PropagateOutputIndexGatherDimsToIndexVectorIndex(
2302 absl::Span<const int64_t> output_index) {
2303 int64_t index_vector_index_i = 0;
2304 for (int64_t i = 0, e = output_index.size(); i < e; i++) {
2305 if (!output_dim_is_batch_dims_[i]) {
2306 continue;
2307 }
2308
2309 if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
2310 index_vector_index_i++;
2311 }
2312
2313 index_vector_index_[index_vector_index_i++] = output_index[i];
2314 }
2315 }
2316
2317 // Populates index_vector_ by iterating over start_indices_ according to
2318 // index_vector_index_.
2319 Status FetchIndexVector() {
2320 int64_t index_vector_dim = dim_numbers_.index_vector_dim();
2321 for (int64_t i = 0, e = index_vector_.size(); i < e; i++) {
2322 index_vector_index_[index_vector_dim] = i;
2323 auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
2324 TF_RET_CHECK(start_index.has_value());
2325 index_vector_[i] = *start_index;
2326 }
2327 return OkStatus();
2328 }
2329
2330 // Populates input_index_.
2331 void PropagateIndexVectorToInputIndex() {
2332 for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
2333 if (input_dim_value_to_index_vector_[i] != -1) {
2334 input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
2335 }
2336
2337 // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
2338 // remains 0, as set by the constructor.
2339 }
2340 }
2341
2342 // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
2343 // the input index from the index vector. See
2344 // PropagateIndexVectorToInputIndex.
2345 std::vector<int64_t> input_dim_value_to_index_vector_;
2346
2347 // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
2348 // dimension.
2349 std::vector<bool> output_dim_is_batch_dims_;
2350
2351 // The buffer into which we construct an index into start_indices_ to fetch
2352 // the index vector.
2353 std::vector<int64_t> index_vector_index_;
2354
2355 // The index vector fetched from start_indices_.
2356 std::vector<int64_t> index_vector_;
2357
2358 // The result computed by this functor. operator() returns a Span into
2359 // this vector.
2360 std::vector<int64_t> input_index_;
2361
2362 const GatherDimensionNumbers& dim_numbers_;
2363 const Literal& start_indices_;
2364 };
2365
2366 // This functor computes the contribution of the offset indices in an output
2367 // index to an input index. That is, given an output index I it picks out the
2368 // output offset indices in I and expands it into an index into the input shape.
2369 class OutputOffsetIndexToInputIndex {
2370 public:
2371 // The constructor does some setup work that is amortized across all
2372 // iterations.
2373 explicit OutputOffsetIndexToInputIndex(
2374 const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
2375 const Shape& output_shape) {
2376 std::vector<int64_t> window_index_to_output_index;
2377 int64_t output_index_count = 0;
2378 for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
2379 if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
2380 window_index_to_output_index.push_back(output_index_count++);
2381 } else {
2382 output_index_count++;
2383 }
2384 }
2385
2386 int64_t window_dim_count = 0;
2387 for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
2388 if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
2389 input_dim_value_to_output_index_.push_back(-1);
2390 } else {
2391 input_dim_value_to_output_index_.push_back(
2392 window_index_to_output_index[window_dim_count++]);
2393 }
2394 }
2395
2396 input_index_.resize(input_shape.dimensions_size());
2397 }
2398
2399 // Returns the contribution of the window indices to the input index
2400 // corresponding to output_index. See gather_inner_loop_body.
2401 //
2402 // This is conceptually a stateless transformation from output_index to the
2403 // window input index, but instead of allocating memory to represent the
2404 // gather input index on every invocation we reuse the same storage for the
2405 // result (input_index_), mutating it in place.
2406 //
2407 // This returns a Span into memory owned by the class.
2408 StatusOr<absl::Span<const int64_t>> operator()(
2409 absl::Span<const int64_t> output_index) {
2410 PropagateOutputIndexWindowDimsToInputIndex(output_index);
2411 return absl::Span<const int64_t>(input_index_);
2412 }
2413
2414 // Returns for a given 'input_dim' the corresponding output dimension index,
2415 // or -1 if 'input_dim' is an elided window dimension.
2416 int64_t input_dim_value_to_output_index(int64_t input_dim) {
2417 return input_dim_value_to_output_index_[input_dim];
2418 }
2419
2420 private:
2421 // Propagates window dimensions from the output index to input_index_ by
2422 // mutating input_index_ in place.
2423 void PropagateOutputIndexWindowDimsToInputIndex(
2424 absl::Span<const int64_t> output_index) {
2425 for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
2426 if (input_dim_value_to_output_index_[i] != -1) {
2427 input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
2428 }
2429
2430 // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
2431 // remains 0, as set by the constructor.
2432 }
2433 }
2434
2435 // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
2436 // the input index from the output index. See
2437 // PropagateOutputIndexWindowDimsToInputIndex.
2438 std::vector<int64_t> input_dim_value_to_output_index_;
2439
2440 // The result computed by this functor. operator() returns a Span into
2441 // this vector.
2442 std::vector<int64_t> input_index_;
2443 };
2444
2445 // Reshapes the gather indices input to have a trailing degenerate `1` dimension
2446 // if necessary. Hands over the ownership of the newly created literal (if
2447 // there is one) to `reshaped_start_indices`.
2448 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
2449 int64_t index_vector_dim, const Literal& start_indices,
2450 Literal* reshaped_start_indices) {
2451 if (start_indices.shape().dimensions_size() != index_vector_dim) {
2452 return std::cref(start_indices);
2453 }
2454
2455 std::vector<int64_t> new_shape(start_indices.shape().dimensions().begin(),
2456 start_indices.shape().dimensions().end());
2457 new_shape.push_back(1);
2458 TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
2459 start_indices.Reshape(new_shape));
2460 return std::cref(*reshaped_start_indices);
2461 }
2462
2463 Status HloEvaluator::HandleGather(HloInstruction* gather) {
2464 Literal result = Literal::CreateFromShape(gather->shape());
2465 const Shape& shape = gather->shape();
2466 const GatherDimensionNumbers& dim_numbers =
2467 gather->gather_dimension_numbers();
2468 const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
2469 Literal reshaped_start_indices;
2470 TF_ASSIGN_OR_RETURN(
2471 const Literal& start_indices,
2472 ReshapedGatherIndices(dim_numbers.index_vector_dim(),
2473 GetEvaluatedLiteralFor(gather->operand(1)),
2474 &reshaped_start_indices));
2475
2476 // We iterate over the gather dimensions in the output shape in an outer loop
2477 // nest, and iterate over the window dimensions in the output shape in an
2478 // inner loop nest.
2479
2480 ShapeUtil::IndexIterationSpace start_indices_iteration_space =
2481 IterationSpaceForOutputBatchIndices(shape, dim_numbers);
2482 ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
2483 IterationSpaceForOutputOffsetIndices(
2484 shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
2485
2486 // Scratch buffers that hold an index in the output shape and the
2487 // corresponding index in the input shape.
2488 std::vector<int64_t> input_index(operand.shape().dimensions_size());
2489 std::vector<int64_t> output_index(gather->shape().dimensions_size());
2490 std::vector<int64_t> input_index_clamped(operand.shape().dimensions_size());
2491
2492 OutputBatchIndexToInputIndex output_batch_index_to_input_index(
2493 &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
2494 /*output_shape=*/shape, &start_indices);
2495 OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
2496 gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
2497 /*output_shape=*/shape);
2498
2499 const Shape& operand_shape = operand.shape();
2500 if (ShapeUtil::IsZeroElementArray(operand_shape)) {
2501 evaluated_[gather] = std::move(result);
2502 return OkStatus();
2503 }
2504
2505 auto gather_inner_loop_body =
2506 [&](absl::Span<const int64_t> output_window_index,
2507 absl::Span<const int64_t> input_gather_index,
2508 absl::Span<const int64_t> output_gather_index) -> StatusOr<bool> {
2509 TF_ASSIGN_OR_RETURN(
2510 absl::Span<const int64_t> input_window_index,
2511 output_offset_index_to_input_index(output_window_index));
2512 for (int i = 0, e = output_index.size(); i < e; i++) {
2513 output_index[i] = output_gather_index[i] + output_window_index[i];
2514 DCHECK_LT(output_index[i], shape.dimensions(i));
2515 }
2516 for (int i = 0, e = input_gather_index.size(); i < e; i++) {
2517 int64_t output_dim =
2518 output_offset_index_to_input_index.input_dim_value_to_output_index(i);
2519 // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
2520 // we set the iteration index to 0, so for the purpose of the following
2521 // calculations we can consider the output dimension size to be 1.
2522 int64_t output_dim_size =
2523 output_dim == -1 ? 1 : shape.dimensions(output_dim);
2524 // Clamp the gather index so that the gather region fits in the operand.
2525 // input_index_clamped[i] = clamp(input_gather_index[i], 0,
2526 // operand_shape.dimensions(i) -
2527 // output_dim_size);
2528 input_index_clamped[i] =
2529 std::min(operand_shape.dimensions(i) - output_dim_size,
2530 std::max(int64_t{0}, input_gather_index[i]));
2531 }
2532 for (int i = 0, e = input_index.size(); i < e; i++) {
2533 input_index[i] = input_index_clamped[i] + input_window_index[i];
2534 DCHECK_GE(input_index[i], 0);
2535 DCHECK_LT(input_index[i], operand_shape.dimensions(i));
2536 }
2537 TF_RETURN_IF_ERROR(
2538 result.CopyElementFrom(operand, input_index, output_index));
2539 return true;
2540 };
2541
2542 auto gather_outer_loop_body =
2543 [&](absl::Span<const int64_t> output_gather_index) -> StatusOr<bool> {
2544 TF_ASSIGN_OR_RETURN(absl::Span<const int64_t> input_gather_index,
2545 output_batch_index_to_input_index(output_gather_index));
2546 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2547 shape, offset_indices_iteration_space,
2548 std::bind(gather_inner_loop_body, std::placeholders::_1,
2549 input_gather_index, output_gather_index)));
2550 return true;
2551 };
2552
2553 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2554 shape, start_indices_iteration_space, gather_outer_loop_body));
2555 evaluated_[gather] = std::move(result);
2556 return OkStatus();
2557 }
2558
2559 namespace {
2560 // Reshapes the scatter indices input to have a trailing degenerate `1`
2561 // dimension if necessary. Hands over the ownership of the newly created
2562 // literal (if there is one) to `reshaped_indices`.
2563 StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
2564 int64_t index_vector_dim, const Literal& indices,
2565 Literal* reshaped_indices) {
2566 if (indices.shape().dimensions_size() != index_vector_dim) {
2567 return std::cref(indices);
2568 }
2569
2570 std::vector<int64_t> new_shape(indices.shape().dimensions().begin(),
2571 indices.shape().dimensions().end());
2572 new_shape.push_back(1);
2573 TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
2574 return std::cref(*reshaped_indices);
2575 }
2576
2577 template <bool kForUpdateWindowIndices>
2578 ShapeUtil::IndexIterationSpace GetIterationSpaceImpl(
2579 absl::Span<const int64_t> updates_dims,
2580 const ScatterDimensionNumbers& dim_numbers) {
2581 int64_t updates_rank = updates_dims.size();
2582 std::vector<int64_t> index_base(updates_rank, 0);
2583 std::vector<int64_t> index_count(updates_rank, 1);
2584 for (int64_t i = 0; i < updates_rank; i++) {
2585 // Use if constexpr when we can use c++17 or above.
2586 if (kForUpdateWindowIndices) {
2587 bool is_update_window_dim =
2588 absl::c_binary_search(dim_numbers.update_window_dims(), i);
2589 if (is_update_window_dim) {
2590 index_count[i] = updates_dims[i];
2591 }
2592 } else {
2593 bool is_update_scatter_dim =
2594 !absl::c_binary_search(dim_numbers.update_window_dims(), i);
2595 if (is_update_scatter_dim) {
2596 index_count[i] = updates_dims[i];
2597 }
2598 }
2599 }
2600 return {std::move(index_base), std::move(index_count),
2601 std::vector<int64_t>(updates_rank, 1)};
2602 }
2603
2604 // Returns a ShapeUtil::IndexIterationSpace that iterates over the update
2605 // scatter dimensions while keeping the rest of the update dimensions clamped
2606 // to 0.
2607 ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices(
2608 absl::Span<const int64_t> updates_dims,
2609 const ScatterDimensionNumbers& dim_numbers) {
2610 return GetIterationSpaceImpl</*kForUpdateWindowIndices=*/false>(updates_dims,
2611 dim_numbers);
2612 }
2613
2614 // Returns a ShapeUtil::IndexIterationSpace that iterates over the update
2615 // window dimensions while keeping the rest of the update dimensions clamped
2616 // to 0.
2617 ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices(
2618 absl::Span<const int64_t> updates_dims,
2619 const ScatterDimensionNumbers& dim_numbers) {
2620 return GetIterationSpaceImpl</*kForUpdateWindowIndices=*/true>(updates_dims,
2621 dim_numbers);
2622 }
2623
2624 // This functor computes the contribution of scatter_indices to an input index
2625 // corresponding to an update index. That is, given an update index I, it
2626 // picks out the scatter indices in I and uses them to look up a scatter
2627 // index, S, from the scatter indices tensor, and expands S into the input
2628 // space according to scatter_dims_to_operand_dims.
2629 //
2630 // This is similar to the OutputBatchIndexToInputIndex class above, which does
2631 // the corresponding work for Gather.
2632 class UpdateScatterIndexToInputIndex {
2633 public:
2634 // The constructor does some setup work that is amortized across all
2635 // iterations.
2636 explicit UpdateScatterIndexToInputIndex(
2637 const ScatterDimensionNumbers& dim_numbers, int64_t input_rank,
2638 int64_t updates_rank, const Literal* scatter_indices)
2639 : dim_numbers_(dim_numbers), scatter_indices_(*scatter_indices) {
2640 for (int64_t i = 0; i < updates_rank; i++) {
2641 update_dim_is_scatter_dims_.push_back(
2642 !absl::c_binary_search(dim_numbers_.update_window_dims(), i));
2643 }
2644
2645 for (int64_t i = 0; i < input_rank; i++) {
2646 int64_t index_of_input_dim_in_index_vector =
2647 FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i);
2648 if (index_of_input_dim_in_index_vector ==
2649 dim_numbers_.scatter_dims_to_operand_dims_size()) {
2650 input_dim_value_to_index_vector_.push_back(-1);
2651 } else {
2652 input_dim_value_to_index_vector_.push_back(
2653 index_of_input_dim_in_index_vector);
2654 }
2655 }
2656
2657 index_vector_index_.resize(scatter_indices_.shape().dimensions_size());
2658 input_index_.resize(input_rank);
2659 int64_t index_vector_size =
2660 scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
2661 index_vector_.resize(index_vector_size);
2662 }
2663
2664 // Returns the contribution of scatter_indices to the input index
2665 // corresponding to update_index. See scatter_inner_loop_body.
2666 //
2667 // This is conceptually a stateless transformation from update_index to the
2668 // scatter input index, but:
2669 //
2670 // - Instead of allocating memory to represent the scatter input index on
2671 // every invocation we reuse the same storage for the result
2672 // (input_index_), mutating it in place.
2673 // - Instead of allocating buffers for temporary values like
2674 // index_vector_index_ and index_vector on every invocation, we reuse the
2675 // same storage for all invocations.
2676 //
2677 // This returns a Span into memory owned by the class.
2678   StatusOr<absl::Span<const int64_t>> operator()(
2679 absl::Span<const int64_t> update_index) {
2680 PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
2681 TF_RETURN_IF_ERROR(FetchIndexVector());
2682 PropagateIndexVectorToInputIndex();
2683 return absl::Span<const int64_t>(input_index_);
2684 }
2685
2686 private:
2687 // Propagates the scatter index dimensions from the update index into
2688 // index_vector_index_ by mutating index_vector_index_ in place. Does not
2689 // update the dim_numbers.index_vector_dim() dimension -- that's the
2690 // dimension we iterate over in FetchIndexVector.
2691   void PropagateUpdateIndexScatterDimsToIndexVectorIndex(
2692 absl::Span<const int64_t> update_index) {
2693 int64_t index_vector_index_i = 0;
2694 for (int64_t i = 0, e = update_index.size(); i < e; i++) {
2695 if (!update_dim_is_scatter_dims_[i]) {
2696 continue;
2697 }
2698
2699 if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
2700 index_vector_index_i++;
2701 }
2702
2703 index_vector_index_[index_vector_index_i++] = update_index[i];
2704 }
2705 }
2706
2707 // Populates index_vector_ by iterating over scatter_indices_ according to
2708 // index_vector_index_.
2709   Status FetchIndexVector() {
2710 int64_t index_vector_dim = dim_numbers_.index_vector_dim();
2711 for (int64_t i = 0, e = index_vector_.size(); i < e; i++) {
2712 index_vector_index_[index_vector_dim] = i;
2713 index_vector_[i] =
2714 *scatter_indices_.GetIntegralAsS64(index_vector_index_);
2715 }
2716 return OkStatus();
2717 }
2718
2719 // Populates input_index_.
2720   void PropagateIndexVectorToInputIndex() {
2721 for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
2722 if (input_dim_value_to_index_vector_[i] != -1) {
2723 input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
2724 }
2725
2726 // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
2727 // remains 0, as set by the constructor.
2728 }
2729 }
2730
2731 // input_dim_value_to_index_vector_[i] tells us how to compute dimension i
2732 // of the input index from the index vector. See
2733 // PropagateIndexVectorToInputIndex.
2734 std::vector<int64_t> input_dim_value_to_index_vector_;
2735
2736 // update_dim_is_scatter_dims_[i] is true iff the update index i is a
2737 // scatter dimension.
2738 std::vector<bool> update_dim_is_scatter_dims_;
2739
2740 // The buffer into which we construct an index into scatter_indices_ to
2741 // fetch the index vector.
2742 std::vector<int64_t> index_vector_index_;
2743
2744 // The index vector fetched from scatter_indices_.
2745 std::vector<int64_t> index_vector_;
2746
2747 // The result computed by this functor. operator() returns a Span
2748 // into this vector.
2749 std::vector<int64_t> input_index_;
2750
2751 const ScatterDimensionNumbers& dim_numbers_;
2752 const Literal& scatter_indices_;
2753 };
2754
2755 // This functor computes the contribution of the window indices in an update
2756 // index to an input index. That is, given an update index I it picks out the
2757 // update window indices in I and expands it into a window index into the
2758 // input shape.
2759 //
2760 // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex
2761 // that performs the corresponding mapping for Gather.
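// Worked example (illustrative, not from the original source): for a rank-3
// operand with inserted_window_dims = {0} and update_window_dims = {1, 2},
// operand dimension 0 has no corresponding update dimension (its window
// contribution stays 0), while operand dimensions 1 and 2 take their window
// offsets from update dimensions 1 and 2 respectively.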
2762 class UpdateWindowIndexToInputIndex {
2763 public:
2764 // The constructor does some setup work that is amortized across all
2765 // iterations.
2766   explicit UpdateWindowIndexToInputIndex(
2767 const ScatterDimensionNumbers& dim_numbers, int64_t input_rank,
2768 int64_t update_rank) {
2769 std::vector<int64_t> window_index_to_update_index;
2770 int64_t update_index_count = 0;
2771 for (int64_t i = 0; i < update_rank; i++) {
2772 if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
2773 window_index_to_update_index.push_back(update_index_count++);
2774 } else {
2775 update_index_count++;
2776 }
2777 }
2778
2779 int64_t window_dim_count = 0;
2780 for (int64_t i = 0; i < input_rank; i++) {
2781 if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
2782 input_dim_value_to_update_index_.push_back(-1);
2783 } else {
2784 input_dim_value_to_update_index_.push_back(
2785 window_index_to_update_index[window_dim_count++]);
2786 }
2787 }
2788
2789 input_index_.resize(input_rank);
2790 }
2791
2792 // Returns the contribution of the window indices to the input index
2793 // corresponding to update_index. See scatter_inner_loop_body.
2794 //
2795 // This is conceptually a stateless transformation from update_index to the
2796 // window input index, but instead of allocating memory to represent the
2797 // scatter input index on every invocation we reuse the same storage for the
2798 // result (input_index_), mutating it in place.
2799 //
2800 // This returns a Span into memory owned by the class.
2801   StatusOr<absl::Span<const int64_t>> operator()(
2802 absl::Span<const int64_t> update_index) {
2803 PropagateUpdateIndexWindowDimsToInputIndex(update_index);
2804 return absl::Span<const int64_t>(input_index_);
2805 }
2806
2807 // Returns for a given 'input_dim' the corresponding update dimension index,
2808 // or -1 if 'input_dim' is an elided window dimension.
2809   int64_t input_dim_value_to_update_index(int64_t input_dim) {
2810 return input_dim_value_to_update_index_[input_dim];
2811 }
2812
2813 private:
2814 // Propagates window dimensions from the update index to input_index_ by
2815 // mutating input_index_ in place.
2816   void PropagateUpdateIndexWindowDimsToInputIndex(
2817 absl::Span<const int64_t> update_index) {
2818 for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
2819 if (input_dim_value_to_update_index_[i] != -1) {
2820 input_index_[i] = update_index[input_dim_value_to_update_index_[i]];
2821 }
2822
2823       // If input_dim_value_to_update_index_[i] == -1 then input_index_[i]
2824 // remains 0, as set by the constructor.
2825 }
2826 }
2827
2828   // input_dim_value_to_update_index_[i] tells us how to compute dimension i
2829 // of the input index from the update index. See
2830 // PropagateUpdateIndexWindowDimsToInputIndex.
2831 std::vector<int64_t> input_dim_value_to_update_index_;
2832
2833 // The result computed by this functor. operator() returns a Span
2834 // into this vector.
2835 std::vector<int64_t> input_index_;
2836 };
2837 } // namespace
2838
2839 Status HloEvaluator::HandleScatter(HloInstruction* hlo) {
2840 auto* scatter = DynCast<HloScatterInstruction>(hlo);
2841 const ScatterDimensionNumbers& dim_numbers =
2842 scatter->scatter_dimension_numbers();
2843 absl::InlinedVector<const Literal*, 1> operands;
2844 operands.reserve(scatter->scatter_operand_count());
2845 for (HloInstruction* operand_inst : scatter->scatter_operands()) {
2846 operands.push_back(&GetEvaluatedLiteralFor(operand_inst));
2847 }
2848 Literal reshaped_scatter_indices;
2849 TF_ASSIGN_OR_RETURN(
2850 const Literal& scatter_indices,
2851 ReshapedScatterIndices(dim_numbers.index_vector_dim(),
2852 GetEvaluatedLiteralFor(scatter->scatter_indices()),
2853 &reshaped_scatter_indices));
2854 absl::InlinedVector<const Literal*, 1> updates;
2855 updates.reserve(operands.size());
2856 for (HloInstruction* updates_inst : scatter->scatter_updates()) {
2857 updates.push_back(&GetEvaluatedLiteralFor(updates_inst));
2858 }
2859 auto updates_dims = updates[0]->shape().dimensions();
2860 auto operand_dims = operands[0]->shape().dimensions();
2861
2862 ShapeUtil::IndexIterationSpace scatter_indices_iteration_space =
2863 IterationSpaceForUpdateScatterIndices(updates_dims, dim_numbers);
2864 ShapeUtil::IndexIterationSpace window_indices_iteration_space =
2865 IterationSpaceForUpdateWindowIndices(updates_dims, dim_numbers);
2866
2867 std::vector<int64_t> input_index(operand_dims.size());
2868 std::vector<int64_t> update_index(updates_dims.size());
2869
2870 UpdateScatterIndexToInputIndex update_scatter_index_to_input_index(
2871 scatter->scatter_dimension_numbers(),
2872 /*input_rank=*/operand_dims.size(), updates_dims.size(),
2873 &scatter_indices);
2874 UpdateWindowIndexToInputIndex update_window_index_to_input_index(
2875 scatter->scatter_dimension_numbers(),
2876 /*input_rank=*/operand_dims.size(), updates_dims.size());
2877
2878 // Initialize the result with the operand. This makes it easier to handle
2879 // the updates even when the indices are repeated.
2880 Literal result = operands.size() > 1 ? LiteralUtil::MakeTuple(operands)
2881 : operands[0]->Clone();
2882 auto maybe_slice = [](MutableLiteralBase& literal, int idx) {
2883 if (literal.shape().IsTuple()) {
2884 return MutableBorrowingLiteral(&literal, {idx});
2885 }
2886 DCHECK_EQ(idx, 0);
2887 return MutableBorrowingLiteral(&literal);
2888 };
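  // Note added for clarity (not in the original source): the evaluation below
  // is a two-level loop. The outer loop walks the scatter dimensions of the
  // update shape and maps each update scatter index to its contribution to
  // the operand index; the inner loop walks the window dimensions, adds the
  // window contribution, applies the scatter computation (to_apply) to the
  // current result element and the update element, and writes the combined
  // value back into `result`.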
2889
2890 HloEvaluator embedded_evaluator;
2891 auto scatter_inner_loop_body =
2892 [&](absl::Span<const int64_t> update_window_index,
2893 absl::Span<const int64_t> input_scatter_index,
2894 absl::Span<const int64_t> update_scatter_index) -> StatusOr<bool> {
2895 TF_ASSIGN_OR_RETURN(
2896 absl::Span<const int64_t> input_window_index,
2897 update_window_index_to_input_index(update_window_index));
2898 for (int i = 0, e = update_index.size(); i < e; i++) {
2899 update_index[i] = update_scatter_index[i] + update_window_index[i];
2900 DCHECK_LT(update_index[i], updates_dims[i]);
2901 }
2902 for (int i = 0, e = input_scatter_index.size(); i < e; i++) {
2903 int64_t update_dim =
2904 update_window_index_to_input_index.input_dim_value_to_update_index(i);
2905 // If 'update_dim' is -1, it means 'i' is an elided window dim. This
2906 // means we set the iteration index to 0, so for the purpose of the
2907 // following calculations we can consider the update dimension size to
2908 // be 1.
2909 int64_t update_dim_size = update_dim == -1 ? 1 : updates_dims[update_dim];
2910 // If any part of the update region is out-of-bounds, then do not
2911 // perform any update on the input.
2912 if ((input_scatter_index[i] < 0) ||
2913 (input_scatter_index[i] > operand_dims[i] - update_dim_size)) {
2914 return true;
2915 }
2916 }
2917 for (int i = 0, e = input_index.size(); i < e; i++) {
2918 input_index[i] = input_scatter_index[i] + input_window_index[i];
2919 }
2920
2921 absl::InlinedVector<Literal, 2> to_apply_args;
2922 to_apply_args.reserve(operands.size() + updates.size());
2923 for (int i = 0, n = operands.size(); i < n; ++i) {
2924 to_apply_args.push_back(
2925 LiteralUtil::GetScalarLiteral(maybe_slice(result, i), input_index));
2926 }
2927 for (int i = 0, n = operands.size(); i < n; ++i) {
2928 to_apply_args.push_back(
2929 LiteralUtil::GetScalarLiteral(*updates[i], update_index));
2930 }
2931 Literal updated_result =
2932 embedded_evaluator.Evaluate(*scatter->to_apply(), to_apply_args)
2933 .value();
2934     // Clear visit states so that we can use the evaluator again on the
2935 // same computation.
2936 embedded_evaluator.ResetVisitStates();
2937 for (int i = 0, n = operands.size(); i < n; ++i) {
2938 auto result_slice = maybe_slice(result, i);
2939 LiteralUtil::SetScalarLiteral(result_slice, input_index,
2940 maybe_slice(updated_result, i));
2941 }
2942 return true;
2943 };
2944
2945 auto scatter_outer_loop_body =
2946 [&](absl::Span<const int64_t> update_scatter_index) -> StatusOr<bool> {
2947 TF_ASSIGN_OR_RETURN(
2948 absl::Span<const int64_t> input_scatter_index,
2949 update_scatter_index_to_input_index(update_scatter_index));
2950 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2951 updates[0]->shape(), window_indices_iteration_space,
2952 [&](absl::Span<const int64_t> update_window_index) {
2953 return scatter_inner_loop_body(
2954 update_window_index, input_scatter_index, update_scatter_index);
2955 }));
2956 return true;
2957 };
2958
2959 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2960 updates[0]->shape(), scatter_indices_iteration_space,
2961 scatter_outer_loop_body));
2962 evaluated_[scatter] = std::move(result);
2963 return OkStatus();
2964 }
2965
2966 Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
2967 const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
2968 TF_RET_CHECK(broadcast->shape().element_type() ==
2969 operand.shape().element_type())
2970 << " broadcast from a different data type is not supported";
2971 TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
2972 << "broadcast dimensions is of size: " << broadcast->dimensions().size()
2973 << " and rank of operand_to_broadcast is: " << operand.shape().rank();
2974   // Checks that the operand's dimension sizes match the broadcast output's
2975   // sizes along the dimensions being broadcast.
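  // Illustrative example (not in the original source): broadcasting an
  // operand of shape f32[3] with dimensions={1} into an output of shape
  // f32[2,3] maps operand dimension 0 to output dimension 1, so the check
  // below compares operand.shape().dimensions(0) == 3 against
  // broadcast->shape().dimensions(1) == 3.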
2976 for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
2977 auto operand_dim_size = operand.shape().dimensions(i);
2978 auto broadcast_dim_size =
2979 broadcast->shape().dimensions(broadcast->dimensions(i));
2980 TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
2981 "Operand dimension %d is broadcast to output dimension %d, but the "
2982 "sizes of these two dims do not match (%d vs %d): %s",
2983 i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
2984 broadcast->ToString());
2985 }
2986
2987 TF_ASSIGN_OR_RETURN(
2988 evaluated_[broadcast],
2989 operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
2990
2991 return OkStatus();
2992 }
2993
2994 Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
2995 evaluated_[after_all] = LiteralUtil::CreateToken();
2996 return OkStatus();
2997 }
2998
2999 Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
3000   // AddDependency just forwards its zeroth operand.
3001 evaluated_[add_dependency] =
3002 GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
3003 return OkStatus();
3004 }
3005
3006 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
3007 const auto result_shape = get_tuple_element->shape();
3008 const int64_t index = get_tuple_element->tuple_index();
3009
3010 auto operand = get_tuple_element->operand(0);
3011 TF_ASSIGN_OR_RETURN(
3012 auto inferred_return_shape,
3013 ShapeInference::InferGetTupleElementShape(operand->shape(), index));
3014 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
3015 << "return shape set to: " << ShapeUtil::HumanString(result_shape)
3016 << " but is inferred to be: "
3017 << ShapeUtil::HumanString(inferred_return_shape);
3018
3019 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
3020
3021 evaluated_[get_tuple_element] =
3022 Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
3023 return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
3024 /*dest_shape_index=*/{},
3025 /*src_shape_index=*/{index});
3026 }
3027
3028 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
3029 TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
3030 evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
3031 return OkStatus();
3032 }
3033
3034 Status HloEvaluator::HandleAsyncStart(HloInstruction* async_start) {
3035 std::vector<const Literal*> arg_literals;
3036 arg_literals.reserve(async_start->operands().size());
3037 for (auto operand : async_start->operands()) {
3038 const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
3039 arg_literals.push_back(&arg_literal);
3040 }
3041
3042 HloEvaluator embedded_evaluator;
3043 embedded_evaluator.set_dynamic_dimension_inference(
3044 dynamic_dimension_inference_);
3045 TF_ASSIGN_OR_RETURN(
3046 Literal result,
3047 embedded_evaluator.Evaluate(*async_start->async_wrapped_computation(),
3048 arg_literals));
3049
3050 evaluated_[async_start] = Literal(async_start->shape());
3051 // Copy the operand values to the index {0, i} of the output.
3052 for (int i = 0; i < arg_literals.size(); ++i) {
3053 TF_RETURN_IF_ERROR(evaluated_[async_start].CopyFrom(
3054 *arg_literals[i], /*dest_shape_index=*/{0, i},
3055 /*src_shape_index=*/{}));
3056 }
3057   // Move the wrapped computation's result to index {1} of the output.
3058 TF_RETURN_IF_ERROR(evaluated_[async_start].MoveFrom(
3059 std::move(result), /*dest_shape_index=*/{1}));
3060
3061 return OkStatus();
3062 }
3063
3064 Status HloEvaluator::HandleAsyncUpdate(HloInstruction* async_update) {
3065 const Literal& operand_tuple_literal =
3066 GetEvaluatedLiteralFor(async_update->operand(0));
3067 evaluated_[async_update] = Literal(async_update->shape());
3068 TF_RETURN_IF_ERROR(evaluated_[async_update].CopyFrom(operand_tuple_literal,
3069 /*dest_shape_index=*/{},
3070 /*src_shape_index=*/{}));
3071 return OkStatus();
3072 }
3073
3074 Status HloEvaluator::HandleAsyncDone(HloInstruction* async_done) {
3075 const Literal& operand_tuple_literal =
3076 GetEvaluatedLiteralFor(async_done->operand(0));
3077 evaluated_[async_done] = Literal(async_done->shape());
3078 TF_RETURN_IF_ERROR(evaluated_[async_done].CopyFrom(operand_tuple_literal,
3079 /*dest_shape_index=*/{},
3080 /*src_shape_index=*/{1}));
3081 return OkStatus();
3082 }
3083
3084 Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
3085 if (copy_start->user_count() != 1 ||
3086 copy_start->users().at(0)->opcode() != HloOpcode::kCopyDone) {
3087 return tensorflow::errors::FailedPrecondition(
3088 "Cannot evaluate a kCopyStart that doesn't have a single kCopyDone "
3089 "user.");
3090 }
3091
3092 // The context in index {2} is undefined, but since we can't represent
3093 // undefined values using a Literal, we just use 0. This should be safe though
3094 // since we ensure that the only user of a kCopyStart is a kCopyDone which
3095 // consumes the context. Also note that MakeTuple copies its arguments, so
3096 // this is memory-safe.
3097 const Literal context_literal = LiteralUtil::CreateR0<uint32_t>(0);
3098 evaluated_[copy_start] = LiteralUtil::MakeTuple(
3099 {&GetEvaluatedLiteralFor(copy_start->operand(0)),
3100 &GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
3101 return OkStatus();
3102 }
3103
3104 Status HloEvaluator::HandleCopyDone(HloInstruction* copy_done) {
3105 const HloInstruction* operand = copy_done->operand(0);
3106 if (operand->opcode() != HloOpcode::kCopyStart) {
3107 return tensorflow::errors::FailedPrecondition(
3108 "Cannot evaluate a kCopyDone that doesn't have a kCopyStart as "
3109 "operand.");
3110 }
3111
3112 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
3113 evaluated_[copy_done] =
3114 Literal(ShapeUtil::GetTupleElementShape(operand->shape(), /*index=*/0));
3115 TF_RETURN_IF_ERROR(evaluated_[copy_done].CopyFrom(operand_tuple_literal,
3116 /*dest_shape_index=*/{},
3117 /*src_shape_index=*/{0}));
3118 return OkStatus();
3119 }
3120
3121 Status HloEvaluator::HandleCall(HloInstruction* call) {
3122 auto* computation = call->to_apply();
3123 auto operands = call->operands();
3124
3125 std::vector<const Literal*> arg_literals;
3126 arg_literals.reserve(operands.size());
3127 for (auto operand : operands) {
3128 const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
3129 arg_literals.push_back(&arg_literal);
3130 }
3131
3132 std::unique_ptr<HloEvaluator> embedded_evaluator =
3133 CreateEmbedded(max_loop_iterations_);
3134 embedded_evaluator->set_dynamic_dimension_inference(
3135 dynamic_dimension_inference_);
3136 TF_ASSIGN_OR_RETURN(Literal result,
3137 embedded_evaluator->Evaluate(*computation, arg_literals));
3138
3139 evaluated_[call] = std::move(result);
3140 return OkStatus();
3141 }
3142
3143 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
3144 HloModuleConfig config;
3145 // Attach cloned computation to an empty HLO module so the existing ones are
3146 // not modified.
3147 HloModule empty_hlo_module("EmptyModuleForFusion", config);
3148 HloCloneContext context(&empty_hlo_module);
3149 auto cloned_fused_computation =
3150 fusion->fused_instructions_computation()->Clone(
3151 /*suffix=*/"clone_with_layout", &context);
3152 for (auto* instruction : cloned_fused_computation->instructions()) {
3153 if (!LayoutUtil::HasLayout(instruction->shape())) {
3154 LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
3155 }
3156 }
3157 auto readded_computation =
3158 empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
3159
3160 auto operands = fusion->operands();
3161 std::vector<const Literal*> arg_literals;
3162 arg_literals.reserve(operands.size());
3163 for (auto operand : operands) {
3164 const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
3165 arg_literals.push_back(&arg_literal);
3166 }
3167
3168 std::unique_ptr<HloEvaluator> embedded_evaluator =
3169 CreateEmbedded(max_loop_iterations_);
3170 embedded_evaluator->set_dynamic_dimension_inference(
3171 dynamic_dimension_inference_);
3172 TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator->Evaluate(
3173 *readded_computation, arg_literals));
3174
3175 evaluated_[fusion] = std::move(result);
3176 return OkStatus();
3177 }
3178
3179 Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
3180 const auto& branch_index_literal =
3181 GetEvaluatedLiteralFor(conditional->operand(0));
3182 int branch_index;
3183 if (conditional->operand(0)->shape().element_type() == PRED) {
3184 branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
3185 } else {
3186 branch_index = branch_index_literal.Get<int32_t>({});
3187 if (branch_index < 0 || branch_index >= conditional->branch_count()) {
3188 branch_index = conditional->branch_count() - 1;
3189 }
3190 }
3191 const auto& branch_computation_arg =
3192 GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
3193
3194 std::unique_ptr<HloEvaluator> embedded_evaluator =
3195 CreateEmbedded(max_loop_iterations_);
3196 embedded_evaluator->set_dynamic_dimension_inference(
3197 dynamic_dimension_inference_);
3198 TF_ASSIGN_OR_RETURN(Literal result,
3199 embedded_evaluator->Evaluate(
3200 *conditional->branch_computation(branch_index),
3201 {&branch_computation_arg}));
3202
3203 evaluated_[conditional] = std::move(result);
3204 return OkStatus();
3205 }
3206
3207 Status HloEvaluator::HandleSelect(HloInstruction* select) {
3208 const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
3209 const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
3210 const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
3211
3212 // If predicate is of scalar type, no element-wise selection would be needed.
3213 if (ShapeUtil::IsScalar(pred.shape())) {
3214 if (pred.Get<bool>({})) {
3215 evaluated_[select] = on_true.Clone();
3216 } else {
3217 evaluated_[select] = on_false.Clone();
3218 }
3219 return OkStatus();
3220 }
3221
3222 return DefaultAction(select);
3223 }
3224
3225 namespace {
3226
3227 StatusOr<Literal> CreateScalarLiteral(int64_t value,
3228 PrimitiveType element_type) {
3229 Literal result;
3230 switch (element_type) {
3231 case S8:
3232 result = LiteralUtil::CreateR0(static_cast<int8_t>(value));
3233 break;
3234 case U8:
3235 result = LiteralUtil::CreateR0(static_cast<uint8_t>(value));
3236 break;
3237 case S16:
3238 result = LiteralUtil::CreateR0(static_cast<int16_t>(value));
3239 break;
3240 case U16:
3241 result = LiteralUtil::CreateR0(static_cast<uint16_t>(value));
3242 break;
3243 case S32:
3244 result = LiteralUtil::CreateR0(static_cast<int32_t>(value));
3245 break;
3246 case U32:
3247 result = LiteralUtil::CreateR0(static_cast<uint32_t>(value));
3248 break;
3249 case S64:
3250 result = LiteralUtil::CreateR0(static_cast<int64_t>(value));
3251 break;
3252 case U64:
3253 result = LiteralUtil::CreateR0(static_cast<uint64_t>(value));
3254 break;
3255 default:
3256 return InvalidArgument("Unsupported element type.");
3257 }
3258 return result;
3259 }
3260
3261 // Parses the while loop if it matches one of the known patterns. Returns the
3262 // value of the loop induction variable after the loop execution if the loop is
3263 // static.
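// Worked example (illustrative, not from the original source): for a static
// loop with induction_var_init_value = 0, step_size = 1 and trip_count = 10,
// the induction variable's returned value is 0 + 10 * 1 = 10, i.e. its value
// after the final iteration.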
3264 StatusOr<Literal> TryParseAndEvaluateWhileInductionVar(
3265 HloInstruction* while_hlo) {
3266 std::optional<ParsedWhileLoop> parsed_while_loop =
3267 PatternMatchParseWhileLoop(while_hlo);
3268 if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic()) {
3269 return FailedPrecondition(
3270 "Cannot evaluate a while loop's induction variable since the loop "
3271 "does not match a known loop pattern or the loop is not static.");
3272 }
3273 int64_t induction_var_value =
3274 parsed_while_loop->static_while_loop->induction_var_init_value +
3275 parsed_while_loop->static_while_loop->trip_count *
3276 parsed_while_loop->static_while_loop->step_size;
3277 Shape result_shape = while_hlo->shape().tuple_shapes(
3278 parsed_while_loop->static_while_loop->induction_var_index);
3279 TF_ASSIGN_OR_RETURN(
3280 Literal result,
3281 CreateScalarLiteral(induction_var_value, result_shape.element_type()));
3282 std::vector<Literal*> while_result_element_ptrs;
3283 while_result_element_ptrs.reserve(while_hlo->shape().tuple_shapes_size());
3284 std::vector<Literal> while_result_elements(
3285 while_hlo->shape().tuple_shapes_size());
3286 for (int i = 0; i < while_hlo->shape().tuple_shapes_size(); ++i) {
3287 if (i == parsed_while_loop->static_while_loop->induction_var_index) {
3288 while_result_element_ptrs.push_back(&result);
3289 } else {
3290 const Shape& shape = while_hlo->shape().tuple_shapes(i);
3291 while_result_elements[i] =
3292           Literal::CreateFromShapeWithUnknownLeafArrays(shape);
      while_result_element_ptrs.push_back(&while_result_elements[i]);
3293     }
3294 }
3295 return LiteralUtil::MakeTuple(while_result_element_ptrs);
3296 }
3297
3298 } // namespace
3299
3300 Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
3301 HloComputation* cond_comp = while_hlo->while_condition();
3302 HloComputation* body_comp = while_hlo->while_body();
3303   // Initialize the loop-carried value with the input to the While instruction.
3304 auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
3305 if (!lcv.IsKnown()) {
3306 std::optional<ParsedWhileLoop> parsed_while_loop =
3307 PatternMatchParseWhileLoop(while_hlo);
3308 evaluated_[while_hlo] =
3309 Literal::CreateFromShapeWithUnknownLeafArrays(while_hlo->shape());
3310 if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic() ||
3311 visitor_shape_index_.size() != 1 ||
3312 parsed_while_loop->static_while_loop->induction_var_index !=
3313 visitor_shape_index_[0]) {
3314 return OkStatus();
3315 }
3316 Shape induction_var_shape =
3317 ShapeUtil::GetSubshape(while_hlo->shape(), visitor_shape_index_);
3318 int64_t trip_count = parsed_while_loop->static_while_loop->trip_count;
3319 TF_ASSIGN_OR_RETURN(
3320 Literal induction_var_val,
3321 CreateScalarLiteral(trip_count, induction_var_shape.element_type()));
3322 TF_RETURN_IF_ERROR(evaluated_[while_hlo].CopyFrom(
3323 induction_var_val, /*dest_shape_index=*/visitor_shape_index_,
3324 /*src_shape_index=*/{}));
3325 return OkStatus();
3326 }
3327 bool keep_going = true;
3328 int64_t iteration_count = 0;
3329 std::unique_ptr<HloEvaluator> cond_evaluator =
3330 CreateEmbedded(max_loop_iterations_);
3331 cond_evaluator->set_dynamic_dimension_inference(dynamic_dimension_inference_);
3332 std::unique_ptr<HloEvaluator> loop_body_evaluator =
3333 CreateEmbedded(max_loop_iterations_);
3334 loop_body_evaluator->set_dynamic_dimension_inference(
3335 dynamic_dimension_inference_);
3336 while (keep_going) {
3337 if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
3338 StatusOr<Literal> result =
3339 TryParseAndEvaluateWhileInductionVar(while_hlo);
3340 if (result.ok()) {
3341 lcv = std::move(result).value();
3342 break;
3343 } else {
3344 return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
3345 while_hlo->name(), max_loop_iterations_);
3346 }
3347 }
3348 TF_ASSIGN_OR_RETURN(auto cond_val,
3349 cond_evaluator->Evaluate(*cond_comp, {&lcv}));
3350 keep_going = cond_val.GetFirstElement<bool>();
3351 if (keep_going) {
3352 TF_ASSIGN_OR_RETURN(auto body_val,
3353 loop_body_evaluator->Evaluate(*body_comp, {&lcv}));
3354 VLOG(3) << "Loop iteration result: " << body_val.ToString();
3355 lcv = std::move(body_val);
3356 cond_evaluator->ResetVisitStates();
3357 loop_body_evaluator->ResetVisitStates();
3358 }
3359 }
3360 evaluated_[while_hlo] = std::move(lcv);
3361 return OkStatus();
3362 }
3363
3364 namespace {
3365 template <typename NativeT>
3366 Literal ExtractLiteralFromIndexPositions(const Literal& from,
3367 absl::Span<int64_t const> indices,
3368 bool extract_as_scalar) {
3369 if (extract_as_scalar) {
3370 return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
3371 }
3372   // We use an InlinedVector here because we need to convert it to an
3373 // absl::Span later, and this would not work with std::vector<bool>.
3374 absl::InlinedVector<NativeT, 10> values;
3375 for (int64_t index : indices) {
3376 values.push_back(from.Get<NativeT>({index}));
3377 }
3378 return LiteralUtil::CreateR1<NativeT>(values);
3379 }
3380
3381 StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
3382 absl::Span<int64_t const> indices,
3383 bool extract_as_scalar = false) {
3384 if (extract_as_scalar) {
3385 CHECK_EQ(indices.size(), 1);
3386 }
3387 PrimitiveType type = from.shape().element_type();
3388 switch (type) {
3389 case PRED: {
3390 return ExtractLiteralFromIndexPositions<bool>(from, indices,
3391 extract_as_scalar);
3392 }
3393 case U8: {
3394 return ExtractLiteralFromIndexPositions<uint8_t>(from, indices,
3395 extract_as_scalar);
3396 }
3397 case S8: {
3398 return ExtractLiteralFromIndexPositions<int8_t>(from, indices,
3399 extract_as_scalar);
3400 }
3401 case BF16: {
3402 return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
3403 extract_as_scalar);
3404 }
3405 case F16: {
3406 return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
3407 extract_as_scalar);
3408 }
3409 case U16: {
3410 return ExtractLiteralFromIndexPositions<uint16_t>(from, indices,
3411 extract_as_scalar);
3412 }
3413 case S16: {
3414 return ExtractLiteralFromIndexPositions<int16_t>(from, indices,
3415 extract_as_scalar);
3416 }
3417 case F32: {
3418 return ExtractLiteralFromIndexPositions<float>(from, indices,
3419 extract_as_scalar);
3420 }
3421 case U32: {
3422 return ExtractLiteralFromIndexPositions<uint32_t>(from, indices,
3423 extract_as_scalar);
3424 }
3425 case S32: {
3426 return ExtractLiteralFromIndexPositions<int32_t>(from, indices,
3427 extract_as_scalar);
3428 }
3429 case F64: {
3430 return ExtractLiteralFromIndexPositions<double>(from, indices,
3431 extract_as_scalar);
3432 }
3433 case C64: {
3434 return ExtractLiteralFromIndexPositions<std::complex<float>>(
3435 from, indices, extract_as_scalar);
3436 }
3437 case U64: {
3438 return ExtractLiteralFromIndexPositions<uint64_t>(from, indices,
3439 extract_as_scalar);
3440 }
3441 case S64: {
3442 return ExtractLiteralFromIndexPositions<int64_t>(from, indices,
3443 extract_as_scalar);
3444 }
3445 case C128: {
3446 return ExtractLiteralFromIndexPositions<std::complex<double>>(
3447 from, indices, extract_as_scalar);
3448 }
3449 default:
3450 return InvalidArgument("Unsupported type for Sort: %s",
3451 PrimitiveType_Name(type));
3452 }
3453 }
3454 } // namespace
3455
3456 Status HloEvaluator::HandleSort(HloInstruction* sort) {
3457 TF_RET_CHECK(sort->operand_count() >= 1)
3458 << "Expected at least 1 operand for sort";
3459 for (int64_t i = 1; i < sort->operand_count(); ++i) {
3460 TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
3461 sort->operand(i)->shape()))
3462 << "All Sort operands must have the same dimensions";
3463 }
3464
3465 if (VLOG_IS_ON(3)) {
3466 for (int64_t i = 0; i < sort->operand_count(); ++i) {
3467 VLOG(3) << "HandleSort operand " << i << " literal: "
3468 << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
3469 }
3470 }
3471 Shape key_shape = sort->operand(0)->shape();
3472 auto rank = key_shape.rank();
3473 std::vector<Literal> result_literals;
3474 result_literals.reserve(sort->operand_count());
3475 for (int64_t i = 0; i < sort->operand_count(); ++i) {
3476 result_literals.emplace_back(sort->operand(i)->shape());
3477 }
3478 std::vector<int64_t> zero_base(rank, 0);
3479 std::vector<int64_t> increment(rank, 1);
3480 int64_t sort_dim = sort->dimensions(0);
3481 int64_t sort_dim_elements = key_shape.dimensions(sort_dim);
3482 TF_RET_CHECK(sort_dim >= 0 && sort_dim < increment.size())
3483 << "Unexpected out-of-bound sort dimension " << sort_dim
3484 << " accessing increment of size " << increment.size();
3485 increment[sort_dim] = sort_dim_elements;
3486 std::unique_ptr<HloEvaluator> embedded_evaluator =
3487 CreateEmbedded(max_loop_iterations_);
3488 // Iterate through each dimension except 'sort_dim'.
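  // Illustrative note (not in the original source): because
  // increment[sort_dim] == sort_dim_elements, ForEachIndex only visits the
  // first element of each row along 'sort_dim'. For a key shape of f32[2,5]
  // sorted along dimension 1, the callback runs for start indices {0,0} and
  // {1,0}, and each invocation sorts one row of 5 elements.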
3489 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
3490 key_shape, zero_base, key_shape.dimensions(), increment,
3491 [&](absl::Span<const int64_t> indices) -> StatusOr<bool> {
3492 // Extract a slice from each operand literal that corresponds to
3493 // exactly the row in dimension 'sort_dim'.
3494 std::vector<int64_t> limit_indices(indices.begin(), indices.end());
3495 absl::c_for_each(limit_indices, [](int64_t& index) { ++index; });
3496 limit_indices[sort_dim] = sort_dim_elements;
3497 std::vector<Literal> literals_to_sort;
3498 literals_to_sort.reserve(sort->operand_count());
3499 for (int64_t i = 0; i < sort->operand_count(); ++i) {
3500 TF_ASSIGN_OR_RETURN(auto literal_to_sort,
3501 GetEvaluatedLiteralFor(sort->operand(i))
3502 .Slice(indices, limit_indices)
3503 .Reshape({sort_dim_elements}));
3504 literals_to_sort.push_back(std::move(literal_to_sort));
3505 }
3506 std::vector<int64_t> indices_to_sort(sort_dim_elements);
3507 std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
3508 Status compare_status = OkStatus();
3509 auto comparator = [sort, &compare_status,
3510 embedded_evaluator = embedded_evaluator.get(),
3511 &literals_to_sort](int64_t a, int64_t b) {
3512 std::vector<Literal> literals;
3513 literals.reserve(2 * sort->operand_count());
3514 for (int64_t i = 0; i < sort->operand_count(); ++i) {
3515 auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
3516 /*extract_as_scalar=*/true);
3517 if (!lhs.ok()) {
3518 compare_status = lhs.status();
3519 return false;
3520 }
3521             literals.push_back(std::move(lhs).value());
3522 auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
3523 /*extract_as_scalar=*/true);
3524 if (!rhs.ok()) {
3525 compare_status = rhs.status();
3526 return false;
3527 }
3528             literals.push_back(std::move(rhs).value());
3529 }
3530 std::vector<const Literal*> literal_ptrs;
3531 absl::c_transform(literals, std::back_inserter(literal_ptrs),
3532 [](const Literal& literal) { return &literal; });
3533
3534 auto computed_result =
3535 embedded_evaluator->Evaluate(*sort->to_apply(), literal_ptrs);
3536 // Clear visit states so that we can use the evaluator again
3537 // on the same computation.
3538 embedded_evaluator->ResetVisitStates();
3539 if (!computed_result.ok()) {
3540 compare_status = computed_result.status();
3541 return false;
3542 }
3543             return computed_result.value().Get<bool>({});
3544 };
3545 if (Cast<HloSortInstruction>(sort)->is_stable()) {
3546 std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
3547 comparator);
3548 } else {
3549 std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
3550 }
3551 if (!compare_status.ok()) {
3552 return compare_status;
3553 }
3554 std::vector<int64_t> slice_dimensions(rank, 1);
3555 slice_dimensions[sort_dim] = sort_dim_elements;
3556 std::vector<int64_t> start_indices(rank, 0);
3557 for (int64_t i = 0; i < sort->operand_count(); ++i) {
3558 TF_ASSIGN_OR_RETURN(
3559 Literal sorted_literal,
3560 ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
3561 TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
3562 sorted_literal.Reshape(slice_dimensions));
3563 TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
3564 sorted_literal_reshaped, start_indices, indices,
3565 slice_dimensions));
3566 }
3567 return true;
3568 }));
3569
3570 if (sort->operand_count() == 1) {
3571 evaluated_[sort] = std::move(result_literals[0]);
3572 } else {
3573 std::vector<const Literal*> literal_ptrs;
3574 absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
3575 [](const Literal& literal) { return &literal; });
3576
3577 Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
3578 VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
3579
3580 evaluated_[sort] = std::move(result_tuple);
3581 }
3582 return OkStatus();
3583 }
3584
3585 static bool IsScalarAdd(HloComputation* computation) {
3586 HloInstruction* instruction = computation->root_instruction();
3587 if (instruction->opcode() == HloOpcode::kAdd &&
3588 computation->num_parameters() == 2) {
3589 const HloInstruction* lhs = instruction->operand(0);
3590 const HloInstruction* rhs = instruction->operand(1);
3591 return lhs->opcode() == HloOpcode::kParameter &&
3592 ShapeUtil::IsScalar(lhs->shape()) &&
3593 rhs->opcode() == HloOpcode::kParameter &&
3594 ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
3595 }
3596 return false;
3597 }
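// Illustrative note (not in the original source): IsScalarAdd matches
// reduction computations of the form
//   p0 = f32[] parameter(0)
//   p1 = f32[] parameter(1)
//   ROOT sum = f32[] add(p0, p1)
// which is the pattern GenerateReduceOutputElement below relies on to take
// the fast summation path.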
3598
3599 // Runs a single step of the reduction inner loop, applying the user-provided
3600 // computation to the accumulator and the current input element (until the
3601 // reduction is completed, the output element is also used as the
3602 // accumulator).
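// Illustrative note (not in the original source): for a plain
// reduce(s32[8], init=0, dimensions={0}, to_apply=add), each call below reads
// the running value at output_index as the accumulator and one input element
// at input_index, evaluates add(accumulator, element), and writes the result
// back to output_index.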
3603 static StatusOr<bool> PerformReductionStep(
3604 bool is_tuple, absl::Span<const int64_t> input_index,
3605 absl::Span<const int64_t> output_index,
3606 absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
3607 HloComputation* computation, HloEvaluator* embedded_evaluator) {
3608 int num_args = results.size();
3609
3610 absl::InlinedVector<Literal, 1> arg_values;
3611 arg_values.reserve(num_args);
3612 absl::InlinedVector<Literal, 1> accumulators;
3613 accumulators.reserve(num_args);
3614 for (int64_t i = 0; i < num_args; ++i) {
3615 arg_values.emplace_back(
3616 ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
3617 accumulators.emplace_back(
3618 ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
3619
3620 TF_RETURN_IF_ERROR(
3621 arg_values[i].CopyElementFrom(*input_args[i], input_index, {}));
3622 TF_RETURN_IF_ERROR(
3623 accumulators[i].CopyElementFrom(results[i], output_index, {}));
3624 }
3625
3626 // Evaluate computation with specified literal operands.
3627 absl::InlinedVector<Literal*, 2> embedded_operands;
3628 for (Literal& accumulator : accumulators) {
3629 embedded_operands.push_back(&accumulator);
3630 }
3631 for (Literal& local_input : arg_values) {
3632 embedded_operands.push_back(&local_input);
3633 }
3634
3635 TF_ASSIGN_OR_RETURN(
3636 Literal computed_result,
3637 embedded_evaluator->Evaluate(*computation, embedded_operands));
3638
3639 // Clear visit states so that we can use the evaluator again on the same
3640 // computation.
3641 embedded_evaluator->ResetVisitStates();
3642
3643 if (is_tuple) {
3644 std::vector<Literal> computed_results = computed_result.DecomposeTuple();
3645 for (int64_t i = 0; i < num_args; ++i) {
3646 TF_RETURN_IF_ERROR(
3647 results[i].CopyElementFrom(computed_results[i], {}, output_index));
3648 }
3649 } else {
3650 TF_RETURN_IF_ERROR(
3651 results[0].CopyElementFrom(computed_result, {}, output_index));
3652 }
3653
3654 return true;
3655 }
3656
3657 static StatusOr<bool> GenerateReduceOutputElement(
3658 bool is_tuple, absl::Span<const int64_t> output_index,
3659
3660 absl::Span<const Literal* const> init_values,
3661 absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
3662
3663 HloComputation* function, HloEvaluator* embedded_evaluator,
3664
3665 absl::Span<const int64_t> arg_dim_steps,
3666 absl::Span<const int64_t> arg_dim_counts,
3667 absl::Span<const int64_t> result_to_arg_index) {
3668 bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) &&
3669 IsScalarAdd(function) && !is_tuple;
3670
3671 const Shape& arg_shape = input_args[0]->shape();
3672 absl::Span<const int64_t> arg_dimensions = arg_shape.dimensions();
3673 std::vector<int64_t> base(arg_dimensions.size());
3674 for (int64_t i = 0; i < output_index.size(); ++i) {
3675 base[result_to_arg_index[i]] = output_index[i];
3676 }
3677
3678 for (int64_t i = 0; i < results.size(); ++i) {
3679 TF_RETURN_IF_ERROR(
3680 results[i].CopyElementFrom(*init_values[i], {}, output_index));
3681 }
3682
3683 if (use_fast_add) {
3684 double computed_result = *init_values[0]->GetAsDouble({});
3685 auto reduction_step =
3686 [&](absl::Span<const int64_t> input_index) -> StatusOr<bool> {
3687 double argument = *input_args[0]->GetAsDouble(input_index);
3688 computed_result += argument;
3689 return true;
3690 };
3691 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
3692 arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step));
3693 TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result));
3694 return true;
3695 }
3696
3697 // Iterates only over reduced shape, as counts and steps are set to zero
3698 // for all non-reduced dimensions.
3699 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
3700 arg_shape, base, arg_dim_counts, arg_dim_steps,
3701 [&](absl::Span<const int64_t> input_index) {
3702 return PerformReductionStep(is_tuple, input_index, output_index,
3703 input_args, results, function,
3704 embedded_evaluator);
3705 }));
3706 return true;
3707 }
3708
3709 Status HloEvaluator::HandleReduce(HloInstruction* instr) {
3710 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
3711 int64_t num_args = reduce->inputs().size();
3712 absl::Span<const int64_t> dimensions_to_reduce(reduce->dimensions());
3713 HloComputation* function = reduce->to_apply();
3714
3715 absl::InlinedVector<const Shape*, 1> operand_shapes;
3716 for (const HloInstruction* operand : reduce->operands()) {
3717 operand_shapes.push_back(&operand->shape());
3718 }
3719 TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
3720 ShapeInference::InferReduceShape(
3721 operand_shapes, dimensions_to_reduce,
3722 /*to_apply=*/function->ComputeProgramShape()));
3723 TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(),
3724 inferred_return_shape))
3725 << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
3726 << " but is inferred to be: "
3727 << ShapeUtil::HumanString(inferred_return_shape);
3728
3729 absl::InlinedVector<const Literal*, 1> input_args(num_args);
3730 absl::InlinedVector<const Literal*, 1> init_values(num_args);
3731 for (int64_t i = 0; i < num_args; ++i) {
3732 input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]);
3733 VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString();
3734 init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]);
3735 VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString();
3736 TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape()));
3737 }
3738
3739 // All args and results have the same dimensions, so pick an arbitrary one.
3740 const Shape& arg_shape = input_args[0]->shape();
3741 const Shape& out_shape = inferred_return_shape;
3742 bool is_tuple = out_shape.IsTuple();
3743 const Shape& output_shape = inferred_return_shape.IsTuple()
3744 ? inferred_return_shape.tuple_shapes(0)
3745 : inferred_return_shape;
3746
3747 absl::Span<const int64_t> arg_dimensions = arg_shape.dimensions();
3748
3749 // All increments are set to 0.
3750 std::vector<int64_t> arg_dim_steps(arg_dimensions.size());
3751
3752 // All counts are set to 0.
3753 std::vector<int64_t> arg_dim_counts(arg_dimensions.size());
3754
3755 // Set steps and counts for reduced dimensions.
3756 // This avoids iterating over non-reduced dimensions, as their step
3757 // and count is set to zero.
3758 for (const int64_t dim : dimensions_to_reduce) {
3759 arg_dim_steps[dim] = 1;
3760 arg_dim_counts[dim] = arg_dimensions[dim];
3761 }
3762
3763 // Map each dimension in the result to a dimension in arg that isn't
3764 // being reduced.
3765 std::vector<int64_t> result_to_arg_index;
3766 for (int64_t i = 0; i < arg_dimensions.size(); ++i) {
3767 if (arg_dim_steps[i] == 0) {
3768 result_to_arg_index.push_back(i);
3769 }
3770 }
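  // Illustrative example (not in the original source): reducing an s32[2,3,4]
  // input over dimensions {1} yields arg_dim_counts = {0, 3, 0},
  // arg_dim_steps = {0, 1, 0} and result_to_arg_index = {0, 2}; output index
  // {i, j} starts the inner iteration at input base {i, 0, j} and walks the
  // three elements of dimension 1.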
3771
3772 std::unique_ptr<HloEvaluator> embedded_evaluator =
3773 CreateEmbedded(max_loop_iterations_);
3774 absl::InlinedVector<Literal, 1> results(num_args);
3775 for (int64_t i = 0; i < num_args; ++i) {
3776 results[i] = Literal(is_tuple ? out_shape.tuple_shapes(i) : out_shape);
3777 }
3778
3779 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
3780 output_shape, [&](absl::Span<const int64_t> output_index) {
3781 return GenerateReduceOutputElement(
3782 is_tuple, output_index, init_values, input_args,
3783 absl::Span<Literal>(results), function, embedded_evaluator.get(),
3784 arg_dim_steps, arg_dim_counts, result_to_arg_index);
3785 }));
3786
3787 if (is_tuple) {
3788 Literal tuple_result(inferred_return_shape);
3789 for (int64_t i = 0; i < num_args; ++i) {
3790 TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
3791 }
3792 evaluated_[reduce] = std::move(tuple_result);
3793 } else {
3794 CHECK_EQ(results.size(), 1);
3795 evaluated_[reduce] = std::move(results[0]);
3796 }
3797 if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) {
3798 TF_ASSIGN_OR_RETURN(evaluated_[reduce],
3799 evaluated_[reduce].ConvertToShape(reduce->shape()));
3800 }
3801 return OkStatus();
3802 }
3803
3804 Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
3805 // Here we delegate the handling to the typed visitor class, instantiated by
3806 // using the type of the first input of ReduceWindow. The support for the
3807 // variadic case inside the typed_visitor is made to not use the template
3808 // parameter so it doesn't really matter which type is used to instantiate it
3809   // here. We choose not to move the implementation for handling ReduceWindow
3810   // from the typed visitor to here because we need to reuse the
3811   // IterateThroughWindow method, which is defined and only available inside the
3812 // typed visitor.
3813 if (hlo->shape().IsTuple()) {
3814 return hlo->Visit(
3815 typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get());
3816 } else {
3817 return DefaultAction(hlo);
3818 }
3819 }
3820
3821 Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
3822 if (!custom_call_handler_) {
3823 // No handler is registered; this means custom-calls are not allowed.
3824 return DefaultAction(custom_call);
3825 }
3826
3827 // Evaluate input operands so the handler has access to the operand data.
3828 std::vector<const Literal*> operands;
3829 operands.reserve(custom_call->operand_count());
3830 for (const HloInstruction* operand : custom_call->operands()) {
3831 operands.push_back(&GetEvaluatedLiteralFor(operand));
3832 }
3833
3834   // Synchronously invoke the handler to populate the instruction's output literal.
3835 TF_ASSIGN_OR_RETURN(
3836 auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));
3837
3838 evaluated_[custom_call] = std::move(output);
3839 return OkStatus();
3840 }
3841
3842 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
3843 VLOG(3) << "About to visit HLO: " << hlo->ToString();
3844 if (!enable_partial_evaluation_) {
3845 for (HloInstruction* operand : hlo->mutable_operands()) {
3846 if (!IsAlreadyEvaluated(operand) ||
3847 !GetEvaluatedLiteralFor(operand).IsKnown()) {
3848 return tensorflow::errors::FailedPrecondition(
3849 "Failed to evaluate instruction since its operands are unknown "
3850 "or undetermined and partial evaluation is not enabled.");
3851 }
3852 }
3853 }
3854 return ShapeUtil::ValidateShape(hlo->shape());
3855 }
3856
3857 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
3858 VLOG(3) << "Finished visiting " << hlo->ToString()
3859 << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
3860   // For convenience the literal may have been produced with a different
3861 // layout. Relayout as indicated by the HLO instruction.
3862 auto evaluated_shape = GetEvaluatedLiteralFor(hlo).shape();
3863 xla::Shape hlo_shape = hlo->shape();
3864 if (hlo_shape.IsArray() && !hlo_shape.has_layout()) {
3865 *hlo_shape.mutable_layout() =
3866 LayoutUtil::GetDefaultLayoutForShape(hlo_shape);
3867 }
3868 if (evaluated_shape.has_layout() && hlo_shape.has_layout() &&
3869 !Layout::Equal().MinorToMajorOnly()(evaluated_shape.layout(),
3870 hlo_shape.layout())) {
3871 evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo_shape);
3872 }
3873 return OkStatus();
3874 }
3875
3876 namespace {
3877 template <typename T>
3878 std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
3879 const Array2D<T>& lhs, const Array2D<T>& rhs,
3880 const std::function<void(const void* run_options_ptr, T* out, T* lhs,
3881 T* rhs, int64_t m, int64_t n, int64_t k,
3882 int32_t transpose_lhs, int32_t transpose_rhs)>&
3883 impl_fn) {
3884 CHECK_EQ(lhs.width(), rhs.height());
3885 int m = lhs.height();
3886 int n = rhs.width();
3887 int k = lhs.width();
3888 auto result = std::make_unique<Array2D<T>>(m, n);
3889 // Because Eigen is a header-oriented library, make sure that the Eigen code
3890 // is the same as the code used by the CPU backend (otherwise the linker will
3891 // randomly pick *some* definition).
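  // Note added for clarity (an assumption about the single-threaded Eigen
  // matmul runtime's column-major convention, not stated in the original
  // source): the operands and the m/n extents are passed swapped (rhs before
  // lhs, n before m), which computes the row-major product lhs * rhs via the
  // identity (lhs * rhs)^T = rhs^T * lhs^T.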
3892 impl_fn(
3893 /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
3894 k,
3895 /*transpose_lhs=*/0,
3896 /*transpose_rhs=*/0);
3897 return result;
3898 }
3899 } // namespace
3900
3901 std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
3902 const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
3903 return MatmulArray2DImpl<Eigen::half>(
3904 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
3905 }
3906
3907 std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
3908 const Array2D<float>& lhs, const Array2D<float>& rhs) {
3909 return MatmulArray2DImpl<float>(
3910 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
3911 }
3912
3913 std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
3914 const Array2D<double>& lhs, const Array2D<double>& rhs) {
3915 return MatmulArray2DImpl<double>(
3916 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
3917 }
3918
3919 std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
3920 const Array2D<std::complex<float>>& lhs,
3921 const Array2D<std::complex<float>>& rhs) {
3922 return MatmulArray2DImpl<std::complex<float>>(
3923 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
3924 }
3925
3926 std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
3927 const Array2D<std::complex<double>>& lhs,
3928 const Array2D<std::complex<double>>& rhs) {
3929 return MatmulArray2DImpl<std::complex<double>>(
3930 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
3931 }
3932
3933 std::unique_ptr<Array2D<int32_t>> HloEvaluator::MatmulArray2D(
3934 const Array2D<int32_t>& lhs, const Array2D<int32_t>& rhs) {
3935 return MatmulArray2DImpl<int32_t>(
3936 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulS32);
3937 }
3938
3939 } // namespace xla
3940