1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16
17 #include <algorithm>
18 #include <cmath>
19 #include <complex>
20 #include <cstdlib>
21 #include <functional>
22 #include <iterator>
23 #include <string>
24 #include <type_traits>
25 #include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/string_view.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/index_util.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/map_util.h"
36 #include "tensorflow/compiler/xla/primitive_util.h"
37 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
38 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
40 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
41 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
42 #include "tensorflow/compiler/xla/service/hlo_query.h"
43 #include "tensorflow/compiler/xla/service/shape_inference.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/statusor.h"
46 #include "tensorflow/compiler/xla/types.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/compiler/xla/window_util.h"
49 #include "tensorflow/core/lib/core/bitmap.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/status.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/types.h"
55
56 namespace xla {
57
58 namespace {
59
60 template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
                          LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
63 std::function<bool(OperandT, OperandT)> compare_op;
64 switch (direction) {
65 case ComparisonDirection::kEq:
66 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
67 return lhs_el == rhs_el;
68 };
69 break;
70 case ComparisonDirection::kNe:
71 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
72 return lhs_el != rhs_el;
73 };
74 break;
75 case ComparisonDirection::kGe:
76 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
77 return lhs_el >= rhs_el;
78 };
79 break;
80 case ComparisonDirection::kGt:
81 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
82 return lhs_el > rhs_el;
83 };
84 break;
85 case ComparisonDirection::kLe:
86 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
87 return lhs_el <= rhs_el;
88 };
89 break;
90 case ComparisonDirection::kLt:
91 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
92 return lhs_el < rhs_el;
93 };
94 break;
95 }
96
97 Literal result(shape);
98 TF_RETURN_IF_ERROR(
99 result.Populate<bool>([&](absl::Span<const int64> multi_index) {
100 return compare_op(lhs_literal.Get<OperandT>(multi_index),
101 rhs_literal.Get<OperandT>(multi_index));
102 }));
103
104 return std::move(result);
105 }
106
107 template <>
StatusOr<Literal> Compare<complex64>(const Shape& shape,
                                     ComparisonDirection direction,
                                     LiteralSlice lhs_literal,
                                     LiteralSlice rhs_literal) {
112 std::function<bool(complex64, complex64)> compare_op;
113 switch (direction) {
114 case ComparisonDirection::kEq:
115 compare_op = [](complex64 lhs_el, complex64 rhs_el) {
116 return lhs_el == rhs_el;
117 };
118 break;
119 case ComparisonDirection::kNe:
120 compare_op = [](complex64 lhs_el, complex64 rhs_el) {
121 return lhs_el != rhs_el;
122 };
123 break;
124 default:
125 LOG(FATAL) << "unhandled direction for conversion to Comparison: "
126 << ComparisonDirectionToString(direction);
127 }
128
129 Literal result(shape);
130 TF_RETURN_IF_ERROR(
131 result.Populate<bool>([&](absl::Span<const int64> multi_index) {
132 return compare_op(lhs_literal.Get<complex64>(multi_index),
133 rhs_literal.Get<complex64>(multi_index));
134 }));
135
136 return std::move(result);
137 }
138
139 template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape,
                                      ComparisonDirection direction,
                                      LiteralSlice lhs_literal,
                                      LiteralSlice rhs_literal) {
144 std::function<bool(complex128, complex128)> compare_op;
145 switch (direction) {
146 case ComparisonDirection::kEq:
147 compare_op = [](complex128 lhs_el, complex128 rhs_el) {
148 return lhs_el == rhs_el;
149 };
150 break;
151 case ComparisonDirection::kNe:
152 compare_op = [](complex128 lhs_el, complex128 rhs_el) {
153 return lhs_el != rhs_el;
154 };
155 break;
156 default:
157 LOG(FATAL) << "unhandled direction for conversion to Comparison: "
158 << ComparisonDirectionToString(direction);
159 }
160
161 Literal result(shape);
162 TF_RETURN_IF_ERROR(
163 result.Populate<bool>([&](absl::Span<const int64> multi_index) {
164 return compare_op(lhs_literal.Get<complex128>(multi_index),
165 rhs_literal.Get<complex128>(multi_index));
166 }));
167
168 return std::move(result);
169 }
170
171 } // namespace
172
// Note that a type being unsupported by the typed visitor does not necessarily
// mean the non-typed HloEvaluator (parent evaluator) cannot handle it in a
// type-agnostic handler. For example, HandleGetTupleElement in the parent
// type-agnostic evaluator accepts the Tuple primitive type, whereas
// HloEvaluatorTypedVisitor cannot.
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
    : max_loop_iterations_(max_loop_iterations) {
180 typed_visitors_[PRED] =
181 absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
182 typed_visitors_[U8] =
183 absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
184 typed_visitors_[U16] =
185 absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
186 typed_visitors_[U32] =
187 absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
188 typed_visitors_[U64] =
189 absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
190 typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
191 typed_visitors_[S16] =
192 absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
193 typed_visitors_[S32] =
194 absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
195 typed_visitors_[S64] =
196 absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
197 typed_visitors_[F16] =
198 absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
199 typed_visitors_[F32] =
200 absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
201 typed_visitors_[F64] =
202 absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
203 typed_visitors_[C64] =
204 absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
205 typed_visitors_[C128] =
206 absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
207
208 // Most of the evaluator computations we use don't support BF16 (e.g.,
209 // std::ceil, std::tanh). To make evaluator work with BF16, we set all
210 // elementwise computations to be done in F32 and do BF16<->F32 conversion
211 // around the input and the output of the computations.
212 typed_visitors_[BF16] =
213 absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
214
215 typed_visitors_[TUPLE] =
216 absl::make_unique<FunctionVisitor>([](HloInstruction*) {
217 return Unimplemented(
218 "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
219 });
220 typed_visitors_[OPAQUE_TYPE] =
221 absl::make_unique<FunctionVisitor>([](HloInstruction*) {
222 return Unimplemented(
223 "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE.");
224 });
225 typed_visitors_[TOKEN] =
226 absl::make_unique<FunctionVisitor>([](HloInstruction*) {
227 return Unimplemented(
228 "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
229 });
230 }
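
// Illustrative sketch (an assumption about surrounding code, not a definition
// in this file): instructions without a dedicated type-agnostic handler are
// dispatched to the visitor registered for their element type, roughly
//
//   return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
//
// which is why every primitive type needs an entry above, even if it is only a
// FunctionVisitor that returns Unimplemented().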
231
StatusOr<Literal> HloEvaluator::Evaluate(
    const HloComputation& computation,
    absl::Span<const Literal* const> arg_literals) {
235 CHECK(computation.parent() != nullptr);
236 XLA_VLOG_LINES(
237 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
238
239 if (arg_literals.size() != computation.num_parameters()) {
240 return InvalidArgument(
241 "Expected %d argument%s, but got %d.", computation.num_parameters(),
242 computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
243 }
244 for (int64 i = 0; i < arg_literals.size(); ++i) {
245 const auto& computation_shape =
246 computation.parameter_instruction(i)->shape();
247 const auto& arg_shape = arg_literals[i]->shape();
248 if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
249 arg_shape)) {
250 return InvalidArgument(
251 "Shape mismatch at parameter %d. Computation expected %s, but arg "
252 "was %s.",
253 i, ShapeUtil::HumanStringWithLayout(computation_shape),
254 ShapeUtil::HumanStringWithLayout(arg_shape));
255 }
256 }
257
258 evaluated_.clear();
259 arg_literals_.clear();
260 for (const auto& literal_ptr : arg_literals) {
261 arg_literals_.push_back(&*literal_ptr);
262 }
263
264 // Re-seed RNG, either from the configuration's seed or a monotonic
265 // per-evaluator seed (which prevents two evaluators from returning the same
266 // random sequence).
267 if (computation.parent()->config().seed()) {
268 seed_ = computation.parent()->config().seed();
269 } else {
270 // Start global_seed at a (true) random value.
271 static std::atomic<uint64> global_seed{std::random_device()()};
272 seed_ = global_seed.fetch_add(1);
273 }
274 engine_.seed(seed_);
275
276 TF_RETURN_IF_ERROR(computation.Accept(this));
277
278 if (VLOG_IS_ON(100)) {
279 for (const HloInstruction* instr : computation.instructions()) {
280 VLOG(100) << instr->name() << " = " << GetEvaluatedLiteralFor(instr);
281 }
282 }
283
284 return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
285 }
286
StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
288 // If the instruction is a kCopyDone, simply find the argument that it is
289 // copied from.
290 while (instruction->opcode() == HloOpcode::kCopyDone) {
291 if (instruction->operand(0)->opcode() != HloOpcode::kCopyStart) {
292 return tensorflow::errors::FailedPrecondition(
293 "kCopyDone has an argument different than a kCopyStart.");
294 }
295 instruction = instruction->mutable_operand(0)->mutable_operand(0);
296 }
297 if (instruction->opcode() == HloOpcode::kParameter) {
298 return tensorflow::errors::FailedPrecondition(
299 "Cannot evaluate a parameter.");
300 }
301 if (!hlo_query::AllOperandsAreConstants(*instruction)) {
302 return tensorflow::errors::FailedPrecondition(
303 "Not all operands are constants.");
304 }
305
306 arg_literals_.clear();
307 evaluated_.clear();
308
309 TF_RETURN_IF_ERROR(Preprocess(instruction));
310 TF_RETURN_IF_ERROR(instruction->Visit(this));
311 TF_RETURN_IF_ERROR(Postprocess(instruction));
312 return GetEvaluatedLiteralFor(instruction).Clone();
313 }
314
bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
316 CHECK(result != nullptr);
317 auto result_or = Evaluate(instruction);
318 if (!result_or.ok()) {
319 VLOG(1) << "TryEvaluate failed:" << result_or.status();
320 return false;
321 }
322
323 *result = result_or.ConsumeValueOrDie();
324 return true;
325 }
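
// Illustrative usage sketch (not part of this file; the evaluator instance and
// the loop-iteration cap are arbitrary): constant-folding style callers
// typically use TryEvaluate as follows.
//
//   HloEvaluator evaluator(/*max_loop_iterations=*/100);
//   Literal folded;
//   if (evaluator.TryEvaluate(instruction, &folded)) {
//     // Use `folded`, e.g. to replace `instruction` with a constant.
//   }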
326
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
    const HloInstruction* instruction,
    const std::unordered_map<const HloInstruction*, const Literal*>&
        substitutions) {
331 std::vector<std::unique_ptr<HloInstruction>> owned_operands;
332 for (const HloInstruction* operand : instruction->operands()) {
333 auto it = substitutions.find(operand);
334 if (it == substitutions.end()) {
335 owned_operands.push_back(operand->Clone());
336 } else {
337 owned_operands.push_back(
338 HloInstruction::CreateConstant(it->second->Clone()));
339 }
340 }
341
342 std::vector<HloInstruction*> operands;
343 operands.reserve(owned_operands.size());
344 for (auto& operand : owned_operands) {
345 operands.push_back(operand.get());
346 }
347
348 std::unique_ptr<HloInstruction> cloned_instruction =
349 instruction->CloneWithNewOperands(instruction->shape(), operands);
350 auto result = Evaluate(cloned_instruction.get());
351
352 return result;
353 }
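
// Illustrative usage sketch (hypothetical instructions `add` and `x`, not part
// of this file): evaluating `add = x + c` with a concrete literal substituted
// for the non-constant operand `x`.
//
//   Literal x_value = LiteralUtil::CreateR0<float>(2.0f);
//   StatusOr<Literal> result =
//       evaluator.EvaluateWithSubstitutions(add, {{x, &x_value}});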
354
StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
    HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
357 std::unique_ptr<HloInstruction> lhs_instr =
358 HloInstruction::CreateConstant(lhs.Clone());
359 std::unique_ptr<HloInstruction> rhs_instr =
360 HloInstruction::CreateConstant(rhs.Clone());
361
362 std::unique_ptr<HloInstruction> cloned_instruction =
363 HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
364 rhs_instr.get());
365 auto result = Evaluate(cloned_instruction.get());
366
367 return result;
368 }
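
// Illustrative usage sketch (not part of this file): folding two constant
// literals through the helper above, e.g. an elementwise add.
//
//   Literal lhs = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
//   Literal rhs = LiteralUtil::CreateR1<float>({3.0f, 4.0f});
//   StatusOr<Literal> sum =
//       evaluator.EvaluateElementwiseBinaryOp(HloOpcode::kAdd, lhs, rhs);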
369
StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
    HloOpcode opcode, const Literal& operand) {
372 std::unique_ptr<HloInstruction> operand_instr =
373 HloInstruction::CreateConstant(operand.Clone());
374
375 std::unique_ptr<HloInstruction> cloned_instruction =
376 HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
377 auto result = Evaluate(cloned_instruction.get());
378
379 return result;
380 }
381
StatusOr<Literal> HloEvaluator::EvaluateDotOp(
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, const Literal& lhs,
    const Literal& rhs) {
386 std::unique_ptr<HloInstruction> lhs_instr =
387 HloInstruction::CreateConstant(lhs.Clone());
388 std::unique_ptr<HloInstruction> rhs_instr =
389 HloInstruction::CreateConstant(rhs.Clone());
390
391 TF_ASSIGN_OR_RETURN(Shape dot_shape,
392 ShapeInference::InferDotOpShape(
393 lhs.shape(), rhs.shape(), dim_numbers,
394 /*preferred_element_type=*/absl::nullopt));
395
396 std::unique_ptr<HloInstruction> cloned_instruction =
397 HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
398 dim_numbers, precision_config);
399 return Evaluate(cloned_instruction.get());
400 }
401
Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
403 const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
404 Literal result(bitcast->shape());
405 TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
406 memcpy(result.untyped_data(), operand_literal.untyped_data(),
407 operand_literal.size_bytes());
408 evaluated_[bitcast] = std::move(result);
409 return Status::OK();
410 }
411
Status HloEvaluator::HandleGetDimensionSize(
    HloInstruction* get_dimension_size) {
414 HloInstruction* operand = get_dimension_size->mutable_operand(0);
415 int64 dim = get_dimension_size->dimension();
416 if (dynamic_dimension_inference_ == nullptr) {
417 return InvalidArgument(
418 "Evaluator cannot evaluate get_dimension_size without "
419 "set_dynamic_dimension_inference.");
420 }
421 HloInstruction* dynamic_size =
422 dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
423 if (dynamic_size != nullptr) {
424 evaluated_[get_dimension_size] =
425 GetEvaluatedLiteralFor(dynamic_size).Clone();
426 return Status::OK();
427 }
428
429 const Shape& shape = get_dimension_size->operand(0)->shape();
430 Literal output(ShapeUtil::MakeShape(S32, {}));
431 output.PopulateWithValue(
432 static_cast<int32>(shape.dimensions(get_dimension_size->dimension())));
433 evaluated_[get_dimension_size] = std::move(output);
434 return Status::OK();
435 }
436
Status HloEvaluator::HandleSetDimensionSize(
    HloInstruction* set_dimension_size) {
439 const Literal& operand_literal =
440 GetEvaluatedLiteralFor(set_dimension_size->operand(0));
441 Literal result(set_dimension_size->shape());
442 memcpy(result.untyped_data(), operand_literal.untyped_data(),
443 operand_literal.size_bytes());
444 const Literal& size_literal =
445 GetEvaluatedLiteralFor(set_dimension_size->operand(1));
446 result.SetDynamicSize(set_dimension_size->dimension(),
447 size_literal.Get<int32>({}));
448 evaluated_[set_dimension_size] = std::move(result);
449 return Status::OK();
450 }
451
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
453 // Nothing to do other than sanity checks. Parameters' values are stored in
454 // arg_literals_.
455 CHECK_LT(parameter->parameter_number(), arg_literals_.size());
456
457 #ifndef NDEBUG
458 const Literal* input_literal = arg_literals_[parameter->parameter_number()];
459 VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
460 DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
461 input_literal->shape()))
462 << "parameter shape is: "
463 << ShapeUtil::HumanStringWithLayout(parameter->shape())
464 << ", but input literal shape is: "
465 << ShapeUtil::HumanStringWithLayout(input_literal->shape());
466 #endif
467
468 return Status::OK();
469 }
470
Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
472
Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
474 TF_ASSIGN_OR_RETURN(
475 evaluated_[reshape],
476 GetEvaluatedLiteralFor(reshape->operand(0))
477 .Reshape(AsInt64Slice(reshape->shape().dimensions())));
478 return Status::OK();
479 }
480
Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
482 evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
483 .Transpose(transpose->dimensions());
484 return Status::OK();
485 }
486
Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
488 absl::Span<HloInstruction* const> operands(concatenate->operands());
  // The result concatenate dimension is going to be the sum of all
  // concatenate dimensions of the operands taking part in the operation.
491 const Shape& reference_shape = operands[0]->shape();
492 CHECK(reference_shape.IsArray());
493 const int64 rank = reference_shape.rank();
494 const int64 concat_dim = concatenate->dimensions()[0];
495 CHECK_GE(concat_dim, 0);
496 CHECK_LT(concat_dim, rank);
497
498 DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
499 reference_shape.dimensions().end());
500
501 for (int64 i = 1; i < operands.size(); ++i) {
502 const Shape& operand_shape = operands[i]->shape();
503 CHECK(operand_shape.IsArray());
    // Accumulate the concat dimension from all tensors taking part in the
    // operation.
506 concat_dimensions[concat_dim] +=
507 ShapeUtil::GetDimension(operand_shape, concat_dim);
508 }
509
510 auto result_literal = LiteralUtil::CreateFromDimensions(
511 reference_shape.element_type(), concat_dimensions);
512 DimensionVector source_indices(rank, 0);
513 DimensionVector dest_indices(concat_dimensions.size(), 0);
514
515 for (auto operand : operands) {
516 const Shape& operand_shape = operand->shape();
517 TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
518 GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
519 AsInt64Slice(operand_shape.dimensions())));
520 dest_indices[concat_dim] +=
521 ShapeUtil::GetDimension(operand_shape, concat_dim);
522 }
523
524 evaluated_[concatenate] = std::move(result_literal);
525 return Status::OK();
526 }
527
Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
529 auto operand = is_finite->operand(0);
530 auto elem_ty = operand->shape().element_type();
531 switch (elem_ty) {
532 case PRED:
533 case TUPLE:
534 case OPAQUE_TYPE:
535 case TOKEN:
536 case S8:
537 case S16:
538 case S32:
539 case S64:
540 case U8:
541 case U16:
542 case U32:
543 case U64:
544 case C64:
545 case C128:
546 // Explicitly enumerate all types in this switch so that when we add a new
547 // type, we'll get a compile error here.
548 case PRIMITIVE_TYPE_INVALID:
549 case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
550 case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
551 return InvalidArgument(
552 "expected element type in shape to be floating point, but "
553 "got: %s",
554 PrimitiveType_Name(elem_ty));
555
556 case F16: {
557 auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
558 is_finite,
559 [](Eigen::half elem_operand) {
560 return std::isfinite(static_cast<float>(elem_operand));
561 },
562 GetEvaluatedLiteralFor(operand));
563 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
564 break;
565 }
566 case BF16: {
567 auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
568 is_finite,
569 [](bfloat16 elem_operand) {
570 return std::isfinite(static_cast<float>(elem_operand));
571 },
572 GetEvaluatedLiteralFor(operand));
573 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
574 break;
575 }
576 case F32: {
577 auto result_or = ElementWiseUnaryOpImpl<bool, float>(
578 is_finite,
579 [](float elem_operand) { return std::isfinite(elem_operand); },
580 GetEvaluatedLiteralFor(operand));
581 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
582 break;
583 }
584 case F64: {
585 auto result_or = ElementWiseUnaryOpImpl<bool, double>(
586 is_finite,
587 [](double elem_operand) { return std::isfinite(elem_operand); },
588 GetEvaluatedLiteralFor(operand));
589 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
590 break;
591 }
592 }
593
594 return Status::OK();
595 }
596
Status HloEvaluator::HandleReal(HloInstruction* real) {
598 auto operand = real->operand(0);
599 switch (operand->shape().element_type()) {
600 case BF16: {
601 auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
602 real, [](bfloat16 elem_operand) { return elem_operand; },
603 GetEvaluatedLiteralFor(operand));
604 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
605 break;
606 }
607 case C64: {
608 auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
609 real, [](complex64 elem_operand) { return std::real(elem_operand); },
610 GetEvaluatedLiteralFor(operand));
611 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
612 break;
613 }
614 case C128: {
615 auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
616 real, [](complex128 elem_operand) { return std::real(elem_operand); },
617 GetEvaluatedLiteralFor(operand));
618 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
619 break;
620 }
621 case F16: {
622 auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
623 real, [](Eigen::half elem_operand) { return elem_operand; },
624 GetEvaluatedLiteralFor(operand));
625 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
626 break;
627 }
628 case F32: {
629 auto result_or = ElementWiseUnaryOpImpl<float, float>(
630 real, [](float elem_operand) { return elem_operand; },
631 GetEvaluatedLiteralFor(operand));
632 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
633 break;
634 }
635 case F64: {
636 auto result_or = ElementWiseUnaryOpImpl<double, double>(
637 real, [](double elem_operand) { return elem_operand; },
638 GetEvaluatedLiteralFor(operand));
639 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
640 break;
641 }
642 default:
643 LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
644 << PrimitiveType_Name(operand->shape().element_type());
645 }
646
647 return Status::OK();
648 }
649
Status HloEvaluator::HandleImag(HloInstruction* imag) {
651 auto operand = imag->operand(0);
652 switch (operand->shape().element_type()) {
653 case C64: {
654 auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
655 imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
656 GetEvaluatedLiteralFor(imag->operand(0)));
657
658 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
659 break;
660 }
661 case C128: {
662 auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
663 imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
664 GetEvaluatedLiteralFor(imag->operand(0)));
665
666 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
667 break;
668 }
669 default:
670 LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
671 << PrimitiveType_Name(operand->shape().element_type());
672 }
673
674 return Status::OK();
675 }
676
Status HloEvaluator::HandleComplex(HloInstruction* complex) {
678 const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
679 const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
680 TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
681
682 Literal result(complex->shape());
683 switch (complex->shape().element_type()) {
684 case C64: {
685 TF_RETURN_IF_ERROR(
686 result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
687 return std::complex<float>(real.Get<float>(multi_index),
688 imag.Get<float>(multi_index));
689 }));
690 break;
691 }
692 case C128: {
693 TF_RETURN_IF_ERROR(
694 result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
695 return std::complex<double>(real.Get<double>(multi_index),
696 imag.Get<double>(multi_index));
697 }));
698 break;
699 }
700 default:
701 LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
702 << PrimitiveType_Name(complex->shape().element_type());
703 }
704
705 evaluated_[complex] = std::move(result);
706 return Status::OK();
707 }
708
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
710 ComparisonDirection direction = compare->comparison_direction();
711 auto lhs = compare->operand(0);
712 auto rhs = compare->operand(1);
713 DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
714 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
715
716 TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
717
718 const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
719 const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
720
721 // Note here we switch on the operand's type.
722 switch (lhs->shape().element_type()) {
723 case PRED: {
724 TF_ASSIGN_OR_RETURN(
725 evaluated_[compare],
726 Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
727 } break;
728 case U8: {
729 TF_ASSIGN_OR_RETURN(evaluated_[compare],
730 Compare<uint8>(compare->shape(), direction,
731 lhs_literal, rhs_literal));
732 } break;
733 case U16: {
734 TF_ASSIGN_OR_RETURN(evaluated_[compare],
735 Compare<uint16>(compare->shape(), direction,
736 lhs_literal, rhs_literal));
737 } break;
738 case U32: {
739 TF_ASSIGN_OR_RETURN(evaluated_[compare],
740 Compare<uint32>(compare->shape(), direction,
741 lhs_literal, rhs_literal));
742 } break;
743 case U64: {
744 TF_ASSIGN_OR_RETURN(evaluated_[compare],
745 Compare<uint64>(compare->shape(), direction,
746 lhs_literal, rhs_literal));
747 } break;
748 case S8: {
749 TF_ASSIGN_OR_RETURN(
750 evaluated_[compare],
751 Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
752 } break;
753 case S16: {
754 TF_ASSIGN_OR_RETURN(evaluated_[compare],
755 Compare<int16>(compare->shape(), direction,
756 lhs_literal, rhs_literal));
757 } break;
758 case S32: {
759 TF_ASSIGN_OR_RETURN(evaluated_[compare],
760 Compare<int32>(compare->shape(), direction,
761 lhs_literal, rhs_literal));
762 } break;
763 case S64: {
764 TF_ASSIGN_OR_RETURN(evaluated_[compare],
765 Compare<int64>(compare->shape(), direction,
766 lhs_literal, rhs_literal));
767 } break;
768 case F16: {
769 TF_ASSIGN_OR_RETURN(
770 evaluated_[compare],
771 Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
772 } break;
773 case BF16: {
774 TF_ASSIGN_OR_RETURN(evaluated_[compare],
775 Compare<bfloat16>(compare->shape(), direction,
776 lhs_literal, rhs_literal));
777 } break;
778 case F32: {
779 TF_ASSIGN_OR_RETURN(evaluated_[compare],
780 Compare<float>(compare->shape(), direction,
781 lhs_literal, rhs_literal));
782 } break;
783 case F64: {
784 TF_ASSIGN_OR_RETURN(evaluated_[compare],
785 Compare<double>(compare->shape(), direction,
786 lhs_literal, rhs_literal));
787 } break;
788 case C64: {
789 TF_ASSIGN_OR_RETURN(evaluated_[compare],
790 Compare<complex64>(compare->shape(), direction,
791 lhs_literal, rhs_literal));
792 } break;
793 case C128: {
794 TF_ASSIGN_OR_RETURN(evaluated_[compare],
795 Compare<complex128>(compare->shape(), direction,
796 lhs_literal, rhs_literal));
797 } break;
798 default:
799 LOG(FATAL) << "HandleCompare: unknown primitive type: "
800 << PrimitiveType_Name(lhs->shape().element_type());
801 }
802
803 return Status::OK();
804 }
805
Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
807 std::vector<const Literal*> operand_literals;
808 for (auto operand : tuple->operands()) {
809 operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
810 }
811
812 evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
813 return Status::OK();
814 }
815
816 namespace {
817
818 // Common code used by 1D implementations, which copies data from the input to
819 // the contiguous buffer. Returns true if all copied values are zero.
bool GatherToBuffer(absl::Span<complex128> data, int64 length, int64 start,
                    int64 stride, bool expand_input,
                    absl::Span<complex128> buffer) {
823 CHECK_GE(buffer.size(), length);
824 bool input_is_zero = true;
825 const int64 ub = expand_input ? length / 2 + 1 : length;
826 CHECK_GE(data.size(), start + (ub - 1) * stride);
827 for (int64 k = 0; k < ub; k++) {
828 complex128 value = data[start + k * stride];
829 input_is_zero &= value == complex128(0.0, 0.0);
830 buffer[k] = value;
831 if (expand_input) {
832 // Use conjugates of the values at indices [1 ... (ub - 2)] when the
833 // length is even and at indices [1 ... (ub - 1)] when the length is odd
834 // to calculate missing values at indices [(length - 1) ... ub].
835 if (k > 0 && k < (length - ub + 1)) {
836 buffer[length - k] = std::conj(value);
837 }
838 }
839 }
840 return input_is_zero;
841 }
842
843 // Returns (conjugated, if 'inverse' is true) k-th twiddle for the given length.
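// In DFT terms, for an N-point transform this is the twiddle factor
//   W_N^k = exp(-2*pi*i*k/N),
// and passing inverse=true yields its conjugate, W_N^{-k}.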
inline complex128 Twiddle(int64 k, int64 length, bool inverse) {
845 auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * k / length));
846 return inverse ? std::conj(coeff) : coeff;
847 }
848
849 // Straightforward implementation of 1D DFT transform of arbitrary length. Uses
850 // passed-in start index and stride to gather inputs from the data vector into
851 // the preallocated buffer, computes the result, and writes it back to the same
852 // locations in the data vector. Runs in O(length^2) time.
853 //
854 // Parameters contract_output and expand_input are used to avoid unnecessary
855 // calculations. When contract_output is set to true, then only (length / 2) + 1
856 // output values are computed. When expand_input is set to true, then
857 // (length / 2) + 1 values from the data set are used to re-create the full set
858 // of size 'length', on which the transform is then performed.
859 //
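// Concretely, the loop below computes, for inputs x[0..length-1],
//   X[k] = sum_{n=0}^{length-1} x[n] * Twiddle(n * k, length, inverse)
// and divides each X[k] by 'length' when 'inverse' is true.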
void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse,
                bool contract_output, bool expand_input,
                absl::Span<complex128> data, absl::Span<complex128> buffer) {
863 const bool input_is_zero =
864 GatherToBuffer(data, length, start, stride, expand_input, buffer);
865
866 if (!input_is_zero) {
867 const int64 ub = contract_output ? length / 2 + 1 : length;
868 for (int64 k = 0; k < ub; k++) {
869 complex128 value = complex128(0.0, 0.0);
870 for (int n = 0; n < length; n++) {
871 value += buffer[n] * Twiddle(n * k, length, inverse);
872 }
873 data[start + k * stride] =
874 inverse ? value / complex128(length, 0.0) : value;
875 }
876 }
877 }
878
// Non-recursive implementation of the Cooley-Tukey radix-2 decimation in time.
// Performs a 1D FFT transform for lengths that are powers of 2. Runs in
// O(length * log(length)) time. Uses the same parameters as the naive
// implementation above, except that the preallocated buffer must be at least
// twice as big as the length of the transform, because the buffer is used to
// hold both input and output values for each stage of the transform.
//
void Fft1D(int64 length, int64 start, int64 stride, bool inverse,
           bool contract_output, bool expand_input, absl::Span<complex128> data,
           absl::Span<complex128> buffer) {
889 CHECK(IsPowerOfTwo(static_cast<uint64>(length)));
890 const bool input_is_zero =
891 GatherToBuffer(data, length, start, stride, expand_input, buffer);
892
893 if (!input_is_zero) {
894 auto generate_twiddles = [](int64 length, bool inverse) {
895 std::vector<complex128> twiddles;
896 // Need only half the twiddles.
897 for (int64 k = 0; k < length / 2; k++) {
898 twiddles.push_back(Twiddle(k, length, inverse));
899 }
900 return twiddles;
901 };
902
903 // Indices into the parts of the buffer used for input and output values.
904 int64 in_base = length;
905 int64 out_base = 0;
906
907 // At each stage, we "split" the input data into num_blocks, with block_size
908 // values in each block.
909 for (int64 num_blocks = 1; num_blocks < length; num_blocks *= 2) {
910 // Swap input and output parts of the buffer.
911 std::swap(in_base, out_base);
912 auto twiddles = generate_twiddles(num_blocks * 2, inverse);
913 const int64 block_size = length / num_blocks;
914 const int64 next_iteration_block_size = block_size / 2;
915 for (int64 block = 0; block < num_blocks; block++) {
916 const int64 in_offset = in_base + block * block_size;
917 const int64 out_offset = out_base + block * next_iteration_block_size;
918 // For each (even, odd) pair of values in the block, calculate two
919 // output values as even + twiddle * odd and even - twiddle * odd.
920 for (int64 pair = 0; pair < block_size / 2; pair++) {
921 const complex128 even = buffer[in_offset + pair];
922 const complex128 odd = buffer[in_offset + block_size / 2 + pair];
923 const complex128 twiddled_odd = twiddles[block] * odd;
924 buffer[out_offset + pair] = even + twiddled_odd;
925 buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
926 }
927 }
928 }
929 // Copy computed result back to data.
930 const int64 ub = contract_output ? length / 2 + 1 : length;
931 for (int64 k = 0; k < ub; k++) {
932 complex128 value = buffer[out_base + k];
933 data[start + k * stride] =
934 inverse ? value / complex128(length, 0.0) : value;
935 }
936 }
937 }
938
// Determines which 1D transform implementation to use and calls it.
void Dft1D(int64 length, int64 start, int64 stride, bool inverse,
           bool contract_output, bool expand_input, absl::Span<complex128> data,
           absl::Span<complex128> buffer) {
943 if (IsPowerOfTwo(static_cast<uint64>(length))) {
944 Fft1D(length, start, stride, inverse, contract_output, expand_input, data,
945 buffer);
946 } else {
947 NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
948 data, buffer);
949 }
950 }
951
952 // Helper to reverse the order of dimension lengths in the passed-in literal.
std::vector<int64> GetDimensionLengths(const Literal& literal) {
954 auto dimensions = literal.shape().dimensions();
955 return std::vector<int64>(dimensions.rbegin(), dimensions.rend());
956 }
957
958 // Helper to compute strides for creating linear indices into multidimensional
959 // data from the dimension lengths and the layout. Returns a new vector of size
960 // lengths.size() + 1. The last element of the returned vector at index
961 // [lengths.size()] contains the product of all dimension lengths.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
                                  const Layout& layout) {
964 const int64 num_dimensions = lengths.size();
965
966 // Make sure that the layout length matches the number of dimensions.
967 CHECK_EQ(num_dimensions, layout.minor_to_major_size());
968
969 // Calculate strides using layout-specified ordering of the dimensions and
970 // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
971 std::vector<int64> strides(num_dimensions + 1);
972 int64 stride = 1;
973 for (int64 i = 0; i < num_dimensions; i++) {
974 // Reverse the ordering of the dimensions in the layout.
975 const int64 index = (num_dimensions - 1) - layout.minor_to_major(i);
976 strides[index] = stride;
977 stride *= lengths[index];
978 }
979 strides[num_dimensions] = stride;
980
981 return strides;
982 }
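
// For example (illustrative, derived from the loop above): with lengths
// {2, 3, 4} and the default layout, the result is {1, 2, 6, 24}; the final
// element (24) is the total number of elements.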
983
984 // Compute strides as above using the default layout.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths) {
986 return ComputeStrides(lengths,
987 LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
988 }
989
990 // Compute strides as above using the layout from the literal, if available.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
                                  const Literal& literal) {
993 return literal.shape().has_layout()
994 ? ComputeStrides(lengths, literal.shape().layout())
995 : ComputeStrides(lengths);
996 }
997
998 // Make 1D sweeps along each transform axis.
void Sweep(int64 fft_rank, FftType fft_type,
           const absl::Span<const int64> fft_lengths,
           const absl::Span<const int64> fft_strides,
           absl::Span<complex128> data, absl::Span<complex128> buffer) {
1003 const bool inverse = fft_type == FftType::IFFT || fft_type == FftType::IRFFT;
1004 const bool input_is_truncated = fft_type == FftType::IRFFT;
1005 const bool output_is_truncated = fft_type == FftType::RFFT;
1006
  // Recursively visit each column of the data along the sweep_axis. Calculate
  // the linearized index of that column's first element and the stride, then
  // invoke the 1D transform.
1010 // For RFFT, avoid calculating unused output values: first, compute only
1011 // (length_x / 2) + 1 values along the X axis, then limit the X coordinate to
1012 // [0 ... (length / 2)] during the sweeps along other axes. Similarly, for
1013 // IRFFT sweep along higher dimensions first, while keeping the X coordinate
1014 // in the [0 ... (length / 2)] range, then re-create negative frequencies
1015 // omitted in the input and perform the full-length transform along the X axis
1016 // in the last sweep.
1017 std::function<void(int64, int64, int64)> sweep = [&](int64 sweep_axis,
1018 int64 axis,
1019 int64 start) {
1020 if (axis < 0) {
1021 // Base case: invoke 1D transform.
1022 const int64 length = fft_lengths[sweep_axis];
1023 const int64 stride = fft_strides[sweep_axis];
1024 const bool expand_input = input_is_truncated && sweep_axis == 0;
      const bool contract_output = output_is_truncated && sweep_axis == 0;
      Dft1D(length, start, stride, inverse, contract_output, expand_input,
            data, buffer);
1028 } else if (axis == sweep_axis) {
1029 // Visit only the elements with coordinate 0 along the sweep axis.
1030 sweep(sweep_axis, axis - 1, start);
1031 } else {
1032 const int64 length = fft_lengths[axis];
1033 const bool is_truncated = input_is_truncated || output_is_truncated;
1034 const int64 ub = is_truncated && axis == 0 ? (length / 2) + 1 : length;
1035 for (int64 i = 0; i < ub; i++) {
1036 sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
1037 }
1038 }
1039 };
1040 if (input_is_truncated) {
1041 // Sweep along the X axis last for IRFFT.
1042 for (int64 sweep_axis = fft_rank - 1; sweep_axis >= 0; sweep_axis--) {
1043 sweep(sweep_axis, fft_rank - 1, 0);
1044 }
1045 } else {
1046 // Sweep along the X axis first for RFFT. The order does not matter for FFT
1047 // and IFFT types; handle them here as well.
1048 for (int64 sweep_axis = 0; sweep_axis < fft_rank; sweep_axis++) {
1049 sweep(sweep_axis, fft_rank - 1, 0);
1050 }
1051 }
1052 }
1053
// These templates convert the data from the input data type to the type used
// in calculations and then to the output data type. They are intended to be
// used only within the DFT implementation. One special case is IRFFT, where
// the specialization drops the imaginary parts of complex values (which are
// expected to be 0) and returns real numbers.
1059 template <typename ToType, typename FromType>
ToType GetAs(FromType value) {
1061 return static_cast<ToType>(value);
1062 }
1063
1064 template <>
float GetAs<float, complex128>(complex128 value) {
1066 return static_cast<float>(value.real());
1067 }
1068
1069 // This template generates two linearized indices, which can be used to access
1070 // multidimensional arrays. It uses a recursive function, which passes the
1071 // indices to the user-supplied callback function. The destination index is
1072 // always within dst_lengths[] bounds. The boolean parameter within_src_bounds
1073 // indicates whether the source index is within src_lengths[] bounds.
1074 //
1075 // The value returned from the callback function controls the recursion depth.
// Returning true indicates that the base case has been hit and the recursion
1077 // stops. Otherwise, the recursion proceeds along the next less-major axis.
1078 //
1079 // For example, the base case when the axis value becomes negative invokes the
1080 // callback function for each possible index within dst_lengths[] bounds. The
1081 // base case when the axis value is equal to zero limits the indices to point
1082 // only to first elements along the minor-most dimension, allowing the callback
1083 // function to handle all values along the X axis.
1084 //
1085 template <typename BaseFn>
void GenerateIndices(const absl::Span<const int64> dst_lengths,
                     const absl::Span<const int64> dst_strides,
                     const absl::Span<const int64> src_lengths,
                     const absl::Span<const int64> src_strides, int64 fft_rank,
                     int64 dst_start, int64 src_start, BaseFn&& base) {
1091 CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
1092 CHECK_GE(dst_lengths.size(), fft_rank);
1093 CHECK_EQ(src_lengths.size() + 1, src_strides.size());
1094 CHECK_GE(src_lengths.size(), fft_rank);
1095
1096 std::function<void(int64, int64, int64, bool)> generate =
1097 [&](int64 axis, int64 dst_index, int64 src_index,
1098 bool within_src_bounds) {
1099 if (!base(axis, dst_index, src_index, within_src_bounds)) {
1100 for (int64 i = 0; i < dst_lengths[axis]; i++) {
1101 // Because the loop goes over dst_lengths[], the source index may be
1102 // out of src_lengths[] bounds. In this case, within_src_bounds is
1103 // false.
1104 within_src_bounds &= i < src_lengths[axis];
1105 generate(axis - 1, dst_index, src_index, within_src_bounds);
1106 dst_index += dst_strides[axis];
1107 src_index += src_strides[axis];
1108 }
1109 }
1110 };
1111 generate(fft_rank - 1, dst_start, src_start, true);
1112 }
1113
1114 // Copies the input data from a literal to a pre-allocated vector. The sizes of
1115 // the input and the transform do not need to match. For each axis of the
1116 // transform, any extra input values beyond the transform length are ignored.
1117 // Conversely, if the input does not contain enough elements along any axis, the
1118 // data is padded with zeroes.
1119 //
1120 // For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
1121 // where length_x is the size of the full transform along the X axis.
1122 //
1123 // The input literal may have a rank higher than the rank of the transform.
1124 // Passed-in input_index value points to the first element of the input literal
1125 // to be copied.
1126 //
1127 // Returns true if all values in the work data set are zeroes.
1128 //
1129 template <typename InputType>
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
                       int64 fft_rank, FftType fft_type, int64 fft_size,
                       const absl::Span<const int64> fft_lengths,
                       const absl::Span<const int64> fft_strides,
                       const absl::Span<const int64> input_lengths,
                       const absl::Span<const int64> input_strides,
                       absl::Span<complex128> data) {
1137 CHECK_GE(data.size(), fft_size);
1138
1139 const bool input_is_truncated = fft_type == FftType::IRFFT;
1140
1141 // Recursively visit each transform dimension to copy input values to the
1142 // working data set. The base case handles inputs along the X axis.
1143 bool input_is_zero = true;
1144 const InputType* input_data = input_literal.data<InputType>().data();
1145 auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
1146 bool within_src_bounds) {
1147 if (axis == 0) {
1148 // For IRFFT, the negative frequencies are only needed for the sweep along
1149 // the X axis, which is performed last. Leave this part of the working set
1150 // uninitialized until then.
1151 const int64 length = fft_lengths[axis];
1152 const int64 ub = input_is_truncated ? (length / 2) + 1 : length;
1153 for (int64 i = 0; i < ub; i++) {
1154 complex128 value = InputType(0);
1155 // Read input value only if the index is within bounds.
1156 if (within_src_bounds && i < input_lengths[axis]) {
1157 value = GetAs<complex128, InputType>(
1158 input_data[src_index + i * input_strides[axis]]);
1159 input_is_zero &= value == complex128(0.0, 0.0);
1160 }
1161 data[dst_index + i * fft_strides[axis]] = value;
1162 }
1163 return true;
1164 }
1165 return false;
1166 };
1167 GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
1168 fft_rank, 0, input_start, base_case);
1169 return input_is_zero;
1170 }
1171
1172 // Copies the result of the transform to the literal output. The sizes of the
1173 // transform and output must match.
1174 //
1175 // For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
1176 // the size of the full transform along the X axis (the most minor dimension).
1177 //
1178 // The output literal may have a rank higher than the rank of the transform.
1179 // Passed-in output_index value points to the first element of the output
1180 // literal to be filled in.
1181 //
1182 template <typename OutputType>
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
                      int64 fft_rank, FftType fft_type,
                      const absl::Span<const int64> fft_lengths,
                      const absl::Span<const int64> fft_strides,
                      const absl::Span<const int64> output_lengths,
                      const absl::Span<const int64> output_strides,
                      Literal* output_literal) {
1190 const bool output_is_truncated = fft_type == FftType::RFFT;
1191
  // Base case for recursive copy of the results to the output. The code avoids
  // making a recursive call for each output element by handling axis 0 in the
  // loop (as opposed to making "axis < 0" the base case).
1195 OutputType* output_data = output_literal->data<OutputType>().data();
1196 auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
1197 bool within_src_bounds) {
1198 if (axis == 0) {
1199 // Drop negative frequencies for RFFT.
1200 const int64 length = fft_lengths[axis];
1201 const int64 ub = output_is_truncated ? (length / 2) + 1 : length;
1202 for (int64 i = 0; i < output_lengths[axis]; i++) {
1203 OutputType value = OutputType(0);
1204 // Read data only if the index is within bounds.
1205 if (within_src_bounds && i < ub) {
1206 value = GetAs<OutputType, complex128>(
1207 data[src_index + i * fft_strides[axis]]);
1208 }
1209 output_data[dst_index + i * output_strides[axis]] = value;
1210 }
1211 return true;
1212 }
1213 return false;
1214 };
1215 GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
1216 fft_rank, output_start, 0, base_case);
1217 }
1218
1219 // Determine the type to use with the CopyDataFromInput<> template above.
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
                       int64 fft_rank, FftType fft_type, int64 fft_size,
                       const absl::Span<const int64> fft_lengths,
                       const absl::Span<const int64> fft_strides,
                       const absl::Span<const int64> input_lengths,
                       const absl::Span<const int64> input_strides,
                       absl::Span<complex128> data) {
1227 const bool input_is_float = fft_type == FftType::RFFT;
1228 if (input_is_float) {
1229 return CopyDataFromInput<float>(
1230 input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
1231 fft_strides, input_lengths, input_strides, data);
1232 } else {
1233 return CopyDataFromInput<complex64>(
1234 input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
1235 fft_strides, input_lengths, input_strides, data);
1236 }
1237 }
1238
1239 // Determine the type to use with the CopyDataToOutput<> template above.
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
                      int64 fft_rank, FftType fft_type,
                      const absl::Span<const int64> fft_lengths,
                      const absl::Span<const int64> fft_strides,
                      const absl::Span<const int64> output_lengths,
                      const absl::Span<const int64> output_strides,
                      Literal* output_literal) {
1247 const bool output_is_float = fft_type == FftType::IRFFT;
1248 if (output_is_float) {
1249 CopyDataToOutput<float>(data, output_start, fft_rank, fft_type, fft_lengths,
1250 fft_strides, output_lengths, output_strides,
1251 output_literal);
1252 } else {
1253 CopyDataToOutput<complex64>(data, output_start, fft_rank, fft_type,
1254 fft_lengths, fft_strides, output_lengths,
1255 output_strides, output_literal);
1256 }
1257 }
1258
Status CheckParameters(const Shape& input_shape, const Shape& output_shape,
                       int64 fft_rank, FftType fft_type,
                       const absl::Span<const int64> fft_lengths) {
1262 // Check FFT parameters.
1263 if (fft_rank <= 0) {
1264 return InvalidArgument("Zero or negative FFT rank.");
1265 }
1266 if (*absl::c_min_element(fft_lengths) < 0) {
1267 return InvalidArgument("Negative FFT length.");
1268 }
1269
1270 // Check input-related values.
1271 TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
1272 if (!input_shape.IsArray()) {
1273 return Unimplemented("Only array input shapes are supported.");
1274 }
1275 auto input_elt_type = input_shape.element_type();
1276 if (fft_type == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
1277 return InvalidArgument("Invalid input type: %d, must be %d (float).",
1278 input_elt_type, PrimitiveType::F32);
1279 }
1280 if (fft_type != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
1281 return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
1282 input_elt_type, PrimitiveType::C64);
1283 }
1284 const int64 input_rank = input_shape.rank();
1285 if (input_rank < fft_rank) {
1286 return InvalidArgument("Input shape rank is smaller than FFT rank.");
1287 }
1288
1289 // Check output-related values.
1290 TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
1291 if (!output_shape.IsArray()) {
1292 return Unimplemented("Only array output shapes are supported.");
1293 }
1294 auto output_elt_type = output_shape.element_type();
1295 if (fft_type == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
1296 return InvalidArgument("Invalid output type: %d, must be %d (float).",
1297 output_elt_type, PrimitiveType::F32);
1298 }
1299 if (fft_type != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
1300 return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
1301 output_elt_type, PrimitiveType::C64);
1302 }
1303 const int64 output_rank = output_shape.rank();
1304 if (output_rank < fft_rank) {
1305 return InvalidArgument("Output shape rank is smaller than FFT rank.");
1306 }
1307
1308 // Consistency of input and output parameters.
1309 if (input_rank != output_rank) {
1310 return InvalidArgument(
1311 "Ranks of input shape and output shape do not match.");
1312 }
1313 for (int64 dim = 0; dim < input_rank - fft_rank; dim++) {
1314 if (ShapeUtil::GetDimension(input_shape, dim) !=
1315 ShapeUtil::GetDimension(output_shape, dim)) {
1316 return InvalidArgument(
1317 "Higher dimension lengths of input shape and output shape do not "
1318 "match.");
1319 }
1320 }
1321
1322 return Status::OK();
1323 }
1324
1325 } // namespace
1326
1327 // Flexible implementation of the discrete Fourier transform. All transform
1328 // types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the arbitrary
1329 // rank and length of each dimension of the transform, and arbitrary layouts of
1330 // the input and output literals.
1331 //
1332 // The input literal in operand 0 provides input data, which must be complex64
1333 // for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed
1334 // over the innermost dimensions of the input, thus the rank of the input data
1335 // must be same as fft_rank or larger. The input is expected to provide Ni
1336 // values along each transform axis with one exception: for IRFFT, only
1337 // (N0 / 2) + 1 values are needed along the X axis (the innermost index). To
1338 // increase flexibility, this implementation can handle mismatches between the
1339 // input size and transform lengths by either dropping extra input values or
1340 // using zeroes in place of missing input values as necessary. If the input data
1341 // has rank higher than the transform, the transform is applied for each valid
1342 // combination of the higher-ranking indices.
1343 //
1344 // The output contains complex64 values for FFT, IFFT, RFFT, and float values
1345 // for IRFFT. The rank of the output as well as the sizes of the dimensions
1346 // above the rank of the transform must match those of the input. Sizes of the
1347 // output's "fft_rank" innermost dimensions are expected to match the length of
1348 // the transform along respective axes with one exception: for RFFT, the output
1349 // is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the
1350 // length(s) mismatch, the FFT output is trimmed to fit into the provided output
1351 // shape, or the output is padded with zero values appropriately.
1352 //
1353 // For example, a 2D FFT transform of size 16x16 applied to a
1354 // complex64[2][15][17] input array will perform two transforms over the
1355 // [][15][17] data in the subarrays [0][][] and [1][][], dropping the extra
1356 // values along axis X and padding axis Y with zeroes to create 16x16 working
1357 // sets, and generating complex64[2][16][16] output. A 3D IRFFT transform of
1358 // size 64x16x16 applied to a complex64[64][16][9] input array will use all
1359 // input values and will produce float[64][16][16] output.
1360 //
1361 // The implementation of the 1D transform for lengths that are powers of 2 is
1362 // the Cooley-Tukey radix-2 decimation-in-time algorithm. For all other 1D
1363 // transform lengths, a straightforward, but slow, loop nest is used. Transforms
1364 // of higher rank apply sets of 1D transforms along each axis. For example, the
1365 // 2D transform is computed by applying 1D transforms to each column followed by
1366 // applying 1D transforms to each row.
1367 //
1368 // In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
1369 // time, where Ni is the length of the transform's i-th dimension. However, for
1370 // dimension lengths that are powers of 2, the run time along these dimensions
1371 // is reduced to log(Ni) in the summation, giving a runtime of
1372 // O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn))) in the best case.
1373 //
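// As a rough illustration of the difference (illustrative numbers only): a 1D
// transform of length 1000 (not a power of 2) takes on the order of
// 1000*1000 = 1,000,000 inner-loop steps, whereas a transform of length 1024
// takes on the order of 1024*log2(1024) = 1024*10, i.e. roughly 10,000 steps.
//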
1374 Status HloEvaluator::HandleFft(HloInstruction* fft) {
1375 const FftType fft_type = fft->fft_type();
1376 std::vector<int64> fft_lengths = fft->fft_length();
1377 const int64 fft_rank = fft_lengths.size();
1378 const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
1379 const Shape& input_shape = input_literal.shape();
1380 const Shape& output_shape = fft->shape();
1381 Literal output_literal = Literal::CreateFromShape(output_shape);
1382
1383 // Make fft_lengths[0] the minor-most dimension.
1384 absl::c_reverse(fft_lengths);
1385
1386 TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape, fft_rank,
1387 fft_type, fft_lengths));
1388
1389 const auto fft_strides = ComputeStrides(fft_lengths);
1390
1391 // Working set size.
1392 const int64 fft_size = fft_strides[fft_rank];
1393
1394 if (fft_size > 0) {
1395 // Linearized working data set.
1396 std::vector<complex128> data(fft_size);
1397
1398 // Temporary buffer allocated once and used in 1D sweeps. For dimension
1399 // length values that are powers of 2, the buffer should be twice as large.
1400 int64 buffer_size = 0;
1401 for (auto len : fft_lengths) {
1402 int64 size = IsPowerOfTwo(static_cast<uint64>(len)) ? len * 2 : len;
1403 buffer_size = std::max(buffer_size, size);
1404 }
1405 std::vector<complex128> buffer(buffer_size);
1406
1407 // Sizes of each axis of input and output literals.
1408 const auto input_lengths = GetDimensionLengths(input_literal);
1409 const auto output_lengths = GetDimensionLengths(output_literal);
1410
1411 // Strides for generating linearized indices into multidimensional arrays.
1412 const auto input_strides = ComputeStrides(input_lengths, input_literal);
1413 const auto output_strides = ComputeStrides(output_lengths, output_literal);
1414
1415 // Visit all elements in the dimensions with ranks above the FFT rank. For
1416 // each such element invoke the transform. Use separate indices for the
1417 // input and the output to allow different layouts.
1418 auto base_case = [&](int64 axis, int64 output_index, int64 input_index,
1419 bool within_src_bounds) {
1420 if (axis == fft_rank - 1) {
1421 // Base case: copy the data from the input literal, apply the
1422 // transform, and copy the result to the output literal.
1423 CHECK(within_src_bounds);
1424 bool input_is_zero =
1425 CopyDataFromInput(input_literal, input_index, fft_rank, fft_type,
1426 fft_size, fft_lengths, fft_strides, input_lengths,
1427 input_strides, absl::MakeSpan(data));
1428 if (!input_is_zero) {
1429 // Make 1D sweeps along each transform axis.
1430 Sweep(fft_rank, fft_type, fft_lengths, fft_strides,
1431 absl::MakeSpan(data), absl::MakeSpan(buffer));
1432 }
1433 CopyDataToOutput(absl::MakeSpan(data), output_index, fft_rank, fft_type,
1434 fft_lengths, fft_strides, output_lengths,
1435 output_strides, &output_literal);
1436 return true;
1437 }
1438 return false;
1439 };
1440 GenerateIndices(output_lengths, output_strides, input_lengths,
1441 input_strides, input_shape.rank(), 0, 0, base_case);
1442 }
1443
1444 evaluated_[fft] = std::move(output_literal);
1445 return Status::OK();
1446 }
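
// A minimal sketch (illustrative only, not used by the evaluator) of the
// straightforward O(N^2) loop nest mentioned above for 1D transform lengths
// that are not powers of 2. It computes X[k] = sum_n x[n] * exp(-2*pi*i*n*k/N)
// for a single 1D slice; the code above additionally handles the inverse and
// real-valued variants, scaling, and strided access into the working set.
inline std::vector<complex128> NaiveDft1DSketch(
    const std::vector<complex128>& x) {
  const double kPi = std::acos(-1.0);
  const int64 n = static_cast<int64>(x.size());
  std::vector<complex128> result(n);
  for (int64 k = 0; k < n; ++k) {
    complex128 sum(0.0, 0.0);
    for (int64 j = 0; j < n; ++j) {
      // Accumulate x[j] * e^{-2*pi*i*j*k/N}.
      const double angle = -2.0 * kPi * static_cast<double>(j * k) / n;
      sum += x[j] * complex128(std::cos(angle), std::sin(angle));
    }
    result[k] = sum;
  }
  return result;
}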
1447
1448 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
1449 // dimensions while keeping the rest of the output dimensions clamped to 0.
1450 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
1451 const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
1452 int64 output_rank = output_shape.dimensions_size();
1453 std::vector<int64> index_base(output_rank, 0);
1454 std::vector<int64> index_count;
1455 index_count.reserve(output_rank);
1456 for (int64 i = 0; i < output_rank; i++) {
1457 bool is_output_batch_dim =
1458 !absl::c_binary_search(dim_numbers.offset_dims(), i);
1459 index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
1460 }
1461
1462 return {std::move(index_base), std::move(index_count),
1463 std::vector<int64>(output_rank, 1)};
1464 }
1465
1466 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
1467 // dimensions while keeping the rest of the output dimensions clamped to 0.
1468 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
1469 int64 output_rank, absl::Span<const int64> slice_sizes,
1470 const GatherDimensionNumbers& dim_numbers) {
1471 std::vector<int64> index_base(output_rank, 0);
1472 std::vector<int64> index_count(output_rank, 1);
1473 int64 slice_sizes_idx = 0;
1474 for (int64 i = 0; i < output_rank; i++) {
1475 bool is_output_window_dim =
1476 absl::c_binary_search(dim_numbers.offset_dims(), i);
1477 if (is_output_window_dim) {
1478 while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
1479 slice_sizes_idx)) {
1480 slice_sizes_idx++;
1481 }
1482 index_count[i] = slice_sizes[slice_sizes_idx++];
1483 }
1484 }
1485
1486 return {std::move(index_base), std::move(index_count),
1487 std::vector<int64>(output_rank, 1)};
1488 }
1489
1490 // This functor computes the contribution of start_indices to an input index
1491 // corresponding to an output index. That is, given an output index I, it picks
1492 // out the batch indices in I and uses them to look up a starting index, G, from
1493 // the start indices tensor, and expands G into the input space according to
1494 // start_index_map.
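// As an illustration (hypothetical shapes, not taken from any particular HLO):
// with start_indices of shape s64[5,2], index_vector_dim=1,
// start_index_map={0,2}, and a rank-3 input, the batch coordinate 3 of an
// output index selects the index vector G = start_indices[3, :] = {g0, g1};
// the functor then contributes the input index {g0, 0, g1}, leaving input
// dimension 1 (which is absent from start_index_map) at zero.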
1495 class OutputBatchIndexToInputIndex {
1496 public:
1497 // The constructor does some setup work that is amortized across all
1498 // iterations.
1499   explicit OutputBatchIndexToInputIndex(
1500 const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
1501 const Shape& output_shape, const Literal* start_indices)
1502 : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
1503 for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
1504 output_dim_is_batch_dims_.push_back(
1505 !absl::c_binary_search(dim_numbers_.offset_dims(), i));
1506 }
1507
1508 for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
1509 int64 index_of_input_dim_in_index_vector =
1510 std::distance(dim_numbers_.start_index_map().begin(),
1511 absl::c_find(dim_numbers_.start_index_map(), i));
1512 if (index_of_input_dim_in_index_vector ==
1513 dim_numbers_.start_index_map_size()) {
1514 input_dim_value_to_index_vector_.push_back(-1);
1515 } else {
1516 input_dim_value_to_index_vector_.push_back(
1517 index_of_input_dim_in_index_vector);
1518 }
1519 }
1520
1521 index_vector_index_.resize(start_indices_.shape().dimensions_size());
1522 input_index_.resize(input_shape.dimensions_size());
1523 int64 index_vector_size =
1524 start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
1525 index_vector_.resize(index_vector_size);
1526 }
1527
1528 // Returns the contribution of start_indices to the input index corresponding
1529 // to output_index. See gather_inner_loop_body.
1530 //
1531 // This is conceptually a stateless transformation from output_index to the
1532 // gather input index, but:
1533 //
1534 // - Instead of allocating memory to represent the gather input index on
1535 // every invocation we reuse the same storage for the result
1536 // (input_index_), mutating it in place.
1537 // - Instead of allocating buffers for temporary values like
1538 // index_vector_index_ and index_vector on every invocation, we reuse the
1539 // same storage for all invocations.
1540 //
1541 // This returns a Span into memory owned by the class.
1542   StatusOr<absl::Span<const int64>> operator()(
1543 absl::Span<const int64> output_index) {
1544 PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
1545 TF_RETURN_IF_ERROR(FetchIndexVector());
1546 PropagateIndexVectorToInputIndex();
1547 return absl::Span<const int64>(input_index_);
1548 }
1549
1550 private:
1551 // Propagates the batch dimensions from the output index into
1552 // index_vector_index_ by mutating index_vector_index_ in place. Does not
1553 // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
1554 // we iterate over in FetchIndexVector.
1555   void PropagateOutputIndexGatherDimsToIndexVectorIndex(
1556 absl::Span<const int64> output_index) {
1557 int64 index_vector_index_i = 0;
1558 for (int64 i = 0, e = output_index.size(); i < e; i++) {
1559 if (!output_dim_is_batch_dims_[i]) {
1560 continue;
1561 }
1562
1563 if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
1564 index_vector_index_i++;
1565 }
1566
1567 index_vector_index_[index_vector_index_i++] = output_index[i];
1568 }
1569 }
1570
1571 // Populates index_vector_ by iterating over start_indices_ according to
1572 // index_vector_index_.
1573   Status FetchIndexVector() {
1574 int64 index_vector_dim = dim_numbers_.index_vector_dim();
1575 for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
1576 index_vector_index_[index_vector_dim] = i;
1577 auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
1578 TF_RET_CHECK(start_index.has_value());
1579 index_vector_[i] = *start_index;
1580 }
1581 return Status::OK();
1582 }
1583
1584 // Populates input_index_.
1585   void PropagateIndexVectorToInputIndex() {
1586 for (int64 i = 0, e = input_index_.size(); i < e; i++) {
1587 if (input_dim_value_to_index_vector_[i] != -1) {
1588 input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
1589 }
1590
1591 // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
1592 // remains 0, as set by the constructor.
1593 }
1594 }
1595
1596 // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
1597 // the input index from the index vector. See
1598 // PropagateIndexVectorToInputIndex.
1599 std::vector<int64> input_dim_value_to_index_vector_;
1600
1601   // output_dim_is_batch_dims_[i] is true iff output dimension i is a batch
1602   // (i.e. gather) dimension.
1603 std::vector<bool> output_dim_is_batch_dims_;
1604
1605 // The buffer into which we construct an index into start_indices_ to fetch
1606 // the index vector.
1607 std::vector<int64> index_vector_index_;
1608
1609 // The index vector fetched from start_indices_.
1610 std::vector<int64> index_vector_;
1611
1612 // The result computed by this functor. operator() returns a Span into
1613 // this vector.
1614 std::vector<int64> input_index_;
1615
1616 const GatherDimensionNumbers& dim_numbers_;
1617 const Literal& start_indices_;
1618 };
1619
1620 // This functor computes the contribution of the offset indices in an output
1621 // index to an input index. That is, given an output index I, it picks out the
1622 // output offset indices in I and expands them into an index into the input shape.
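// As an illustration (hypothetical shapes, not taken from any particular HLO):
// with a rank-3 input, offset_dims={1,2}, and collapsed_slice_dims={0}, input
// dimension 0 is collapsed (it stays at zero), while input dimensions 1 and 2
// are taken directly from output dimensions 1 and 2 of the offset part of the
// output index.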
1623 class OutputOffsetIndexToInputIndex {
1624 public:
1625 // The constructor does some setup work that is amortized across all
1626 // iterations.
1627   explicit OutputOffsetIndexToInputIndex(
1628 const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
1629 const Shape& output_shape) {
1630 std::vector<int64> window_index_to_output_index;
1631 int64 output_index_count = 0;
1632 for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
1633 if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
1634 window_index_to_output_index.push_back(output_index_count++);
1635 } else {
1636 output_index_count++;
1637 }
1638 }
1639
1640 int64 window_dim_count = 0;
1641 for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
1642 if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
1643 input_dim_value_to_output_index_.push_back(-1);
1644 } else {
1645 input_dim_value_to_output_index_.push_back(
1646 window_index_to_output_index[window_dim_count++]);
1647 }
1648 }
1649
1650 input_index_.resize(input_shape.dimensions_size());
1651 }
1652
1653 // Returns the contribution of the window indices to the input index
1654 // corresponding to output_index. See gather_inner_loop_body.
1655 //
1656 // This is conceptually a stateless transformation from output_index to the
1657 // window input index, but instead of allocating memory to represent the
1658 // gather input index on every invocation we reuse the same storage for the
1659 // result (input_index_), mutating it in place.
1660 //
1661 // This returns a Span into memory owned by the class.
1662   StatusOr<absl::Span<const int64>> operator()(
1663 absl::Span<const int64> output_index) {
1664 PropagateOutputIndexWindowDimsToInputIndex(output_index);
1665 return absl::Span<const int64>(input_index_);
1666 }
1667
1668 // Returns for a given 'input_dim' the corresponding output dimension index,
1669 // or -1 if 'input_dim' is an elided window dimension.
1670   int64 input_dim_value_to_output_index(int64 input_dim) {
1671 return input_dim_value_to_output_index_[input_dim];
1672 }
1673
1674 private:
1675 // Propagates window dimensions from the output index to input_index_ by
1676 // mutating input_index_ in place.
1677   void PropagateOutputIndexWindowDimsToInputIndex(
1678 absl::Span<const int64> output_index) {
1679 for (int64 i = 0, e = input_index_.size(); i < e; i++) {
1680 if (input_dim_value_to_output_index_[i] != -1) {
1681 input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
1682 }
1683
1684       // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
1685 // remains 0, as set by the constructor.
1686 }
1687 }
1688
1689   // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
1690 // the input index from the output index. See
1691 // PropagateOutputIndexWindowDimsToInputIndex.
1692 std::vector<int64> input_dim_value_to_output_index_;
1693
1694 // The result computed by this functor. operator() returns a Span into
1695 // this vector.
1696 std::vector<int64> input_index_;
1697 };
1698
1699 // Reshapes the gather indices input to have a trailing degenerate `1` dimension
1700 // if necessary. Hands over the ownership of the newly created literal (if
1701 // there is one) to `reshaped_start_indices`.
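// For example (illustrative): start_indices of shape s64[5] with
// index_vector_dim=1 is reshaped to s64[5,1]; start_indices whose rank already
// exceeds index_vector_dim are passed through unchanged.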
1702 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
1703 int64 index_vector_dim, const Literal& start_indices,
1704 Literal* reshaped_start_indices) {
1705 if (start_indices.shape().dimensions_size() != index_vector_dim) {
1706 return std::cref(start_indices);
1707 }
1708
1709 std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
1710 start_indices.shape().dimensions().end());
1711 new_shape.push_back(1);
1712 TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
1713 start_indices.Reshape(new_shape));
1714 return std::cref(*reshaped_start_indices);
1715 }
1716
1717 Status HloEvaluator::HandleGather(HloInstruction* gather) {
1718 Literal result = Literal::CreateFromShape(gather->shape());
1719 const Shape& shape = gather->shape();
1720 const GatherDimensionNumbers& dim_numbers =
1721 gather->gather_dimension_numbers();
1722 const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
1723 Literal reshaped_start_indices;
1724 TF_ASSIGN_OR_RETURN(
1725 const Literal& start_indices,
1726 ReshapedGatherIndices(dim_numbers.index_vector_dim(),
1727 GetEvaluatedLiteralFor(gather->operand(1)),
1728 &reshaped_start_indices));
1729
1730 // We iterate over the gather dimensions in the output shape in an outer loop
1731 // nest, and iterate over the window dimensions in the output shape in an
1732 // inner loop nest.
1733
1734 ShapeUtil::IndexIterationSpace start_indices_iteration_space =
1735 IterationSpaceForOutputBatchIndices(shape, dim_numbers);
1736 ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
1737 IterationSpaceForOutputOffsetIndices(
1738 shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
1739
1740 // Scratch buffers that hold an index in the output shape and the
1741 // corresponding index in the input shape.
1742 std::vector<int64> input_index(operand.shape().dimensions_size());
1743 std::vector<int64> output_index(gather->shape().dimensions_size());
1744 std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
1745
1746 OutputBatchIndexToInputIndex output_batch_index_to_input_index(
1747 &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1748 /*output_shape=*/shape, &start_indices);
1749 OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
1750 gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1751 /*output_shape=*/shape);
1752
1753 const Shape& operand_shape = operand.shape();
1754 if (ShapeUtil::IsZeroElementArray(operand_shape)) {
1755 evaluated_[gather] = std::move(result);
1756 return Status::OK();
1757 }
1758
1759 auto gather_inner_loop_body =
1760 [&](absl::Span<const int64> output_window_index,
1761 absl::Span<const int64> input_gather_index,
1762 absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1763 TF_ASSIGN_OR_RETURN(
1764 absl::Span<const int64> input_window_index,
1765 output_offset_index_to_input_index(output_window_index));
1766 for (int i = 0, e = output_index.size(); i < e; i++) {
1767 output_index[i] = output_gather_index[i] + output_window_index[i];
1768 DCHECK_LT(output_index[i], shape.dimensions(i));
1769 }
1770 for (int i = 0, e = input_gather_index.size(); i < e; i++) {
1771 int64 output_dim =
1772 output_offset_index_to_input_index.input_dim_value_to_output_index(i);
1773 // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
1774 // we set the iteration index to 0, so for the purpose of the following
1775 // calculations we can consider the output dimension size to be 1.
1776 int64 output_dim_size =
1777 output_dim == -1 ? 1 : shape.dimensions(output_dim);
1778 // Clamp the gather index so that the gather region fits in the operand.
1779 // input_index_clamped[i] = clamp(input_gather_index[i], 0,
1780 // operand_shape.dimensions(i) -
1781 // output_dim_size);
1782 input_index_clamped[i] =
1783 std::min(operand_shape.dimensions(i) - output_dim_size,
1784 std::max(int64{0}, input_gather_index[i]));
1785 }
1786 for (int i = 0, e = input_index.size(); i < e; i++) {
1787 input_index[i] = input_index_clamped[i] + input_window_index[i];
1788 DCHECK_GE(input_index[i], 0);
1789 DCHECK_LT(input_index[i], operand_shape.dimensions(i));
1790 }
1791 TF_RETURN_IF_ERROR(
1792 result.CopyElementFrom(operand, input_index, output_index));
1793 return true;
1794 };
1795
1796 auto gather_outer_loop_body =
1797 [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1798 TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
1799 output_batch_index_to_input_index(output_gather_index));
1800 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1801 shape, offset_indices_iteration_space,
1802 std::bind(gather_inner_loop_body, std::placeholders::_1,
1803 input_gather_index, output_gather_index)));
1804 return true;
1805 };
1806
1807 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1808 shape, start_indices_iteration_space, gather_outer_loop_body));
1809 evaluated_[gather] = std::move(result);
1810 return Status::OK();
1811 }
1812
1813 Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
1814 const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
1815
1816 TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
1817 << "broadcast dimensions is of size: " << broadcast->dimensions().size()
1818 << " and rank of operand_to_broadcast is: " << operand.shape().rank();
1819   // Checks that the operand's dimensions are the same as the broadcast's
1820   // dimensions along the dimensions to be broadcast.
1821 for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
1822 auto operand_dim_size = operand.shape().dimensions(i);
1823 auto broadcast_dim_size =
1824 broadcast->shape().dimensions(broadcast->dimensions(i));
1825 TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
1826 "Operand dimension %d is broadcast to output dimension %d, but the "
1827 "sizes of these two dims do not match (%d vs %d): %s",
1828 i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
1829 broadcast->ToString());
1830 }
1831
1832 TF_ASSIGN_OR_RETURN(
1833 evaluated_[broadcast],
1834 operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
1835
1836 return Status::OK();
1837 }
1838
1839 Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
1840 evaluated_[after_all] = LiteralUtil::CreateToken();
1841 return Status::OK();
1842 }
1843
1844 Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
1845   // AddDependency just forwards its zero-th operand.
1846 evaluated_[add_dependency] =
1847 GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
1848 return Status::OK();
1849 }
1850
1851 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
1852 const auto result_shape = get_tuple_element->shape();
1853 const int64 index = get_tuple_element->tuple_index();
1854
1855 auto operand = get_tuple_element->operand(0);
1856 TF_ASSIGN_OR_RETURN(
1857 auto inferred_return_shape,
1858 ShapeInference::InferGetTupleElementShape(operand->shape(), index));
1859 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
1860 << "return shape set to: " << ShapeUtil::HumanString(result_shape)
1861 << " but is inferred to be: "
1862 << ShapeUtil::HumanString(inferred_return_shape);
1863
1864 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1865
1866 evaluated_[get_tuple_element] =
1867 Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
1868 return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
1869 /*dest_shape_index=*/{},
1870 /*src_shape_index=*/{index});
1871 }
1872
1873 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
1874 TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
1875 evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
1876 return Status::OK();
1877 }
1878
1879 Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
1880 if (copy_start->user_count() != 1 ||
1881 copy_start->users().at(0)->opcode() != HloOpcode::kCopyDone) {
1882 return tensorflow::errors::FailedPrecondition(
1883 "Cannot evaluate a kCopyStart that doesn't have a single kCopyDone "
1884 "user.");
1885 }
1886
1887 // The context in index {2} is undefined, but since we can't represent
1888 // undefined values using a Literal, we just use 0. This should be safe though
1889 // since we ensure that the only user of a kCopyStart is a kCopyDone which
1890 // consumes the context. Also note that MakeTuple copies its arguments, so
1891 // this is memory-safe.
1892 const Literal context_literal = LiteralUtil::CreateR0<uint32>(0);
1893 evaluated_[copy_start] = LiteralUtil::MakeTuple(
1894 {&GetEvaluatedLiteralFor(copy_start->operand(0)),
1895 &GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
1896 return Status::OK();
1897 }
1898
1899 Status HloEvaluator::HandleCopyDone(HloInstruction* copy_done) {
1900 const HloInstruction* operand = copy_done->operand(0);
1901 if (operand->opcode() != HloOpcode::kCopyStart) {
1902 return tensorflow::errors::FailedPrecondition(
1903 "Cannot evaluate a kCopyDone that doesn't have a kCopyStart as "
1904 "operand.");
1905 }
1906
1907 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1908
1909 evaluated_[copy_done] =
1910 Literal(ShapeUtil::GetTupleElementShape(operand->shape(), /*index=*/0));
1911 TF_RETURN_IF_ERROR(evaluated_[copy_done].CopyFrom(operand_tuple_literal,
1912 /*dest_shape_index=*/{},
1913 /*src_shape_index=*/{0}));
1914 return Status::OK();
1915 }
1916
1917 Status HloEvaluator::HandleCall(HloInstruction* call) {
1918 auto* computation = call->to_apply();
1919 auto operands = call->operands();
1920
1921 std::vector<const Literal*> arg_literals;
1922 arg_literals.reserve(operands.size());
1923 for (auto operand : operands) {
1924 const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1925 arg_literals.push_back(&arg_literal);
1926 }
1927
1928 HloEvaluator embedded_evaluator;
1929 embedded_evaluator.set_dynamic_dimension_inference(
1930 dynamic_dimension_inference_);
1931 TF_ASSIGN_OR_RETURN(Literal result,
1932 embedded_evaluator.Evaluate(*computation, arg_literals));
1933
1934 evaluated_[call] = std::move(result);
1935 return Status::OK();
1936 }
1937
1938 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
1939 HloModuleConfig config;
1940 // Attach cloned computation to an empty HLO module so the existing ones are
1941 // not modified.
1942 HloModule empty_hlo_module("EmptyModuleForFusion", config);
1943 HloCloneContext context(&empty_hlo_module);
1944 auto cloned_fused_computation =
1945 fusion->fused_instructions_computation()->Clone(
1946 /*suffix=*/"clone_with_layout", &context);
1947 for (auto* instruction : cloned_fused_computation->instructions()) {
1948 if (!LayoutUtil::HasLayout(instruction->shape())) {
1949 LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
1950 }
1951 }
1952 auto readded_computation =
1953 empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
1954
1955 auto operands = fusion->operands();
1956 std::vector<const Literal*> arg_literals;
1957 arg_literals.reserve(operands.size());
1958 for (auto operand : operands) {
1959 const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1960 arg_literals.push_back(&arg_literal);
1961 }
1962
1963 HloEvaluator embedded_evaluator;
1964 embedded_evaluator.set_dynamic_dimension_inference(
1965 dynamic_dimension_inference_);
1966 TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
1967 *readded_computation, arg_literals));
1968
1969 evaluated_[fusion] = std::move(result);
1970 return Status::OK();
1971 }
1972
1973 Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
1974 const auto& branch_index_literal =
1975 GetEvaluatedLiteralFor(conditional->operand(0));
1976 int branch_index;
1977 if (conditional->operand(0)->shape().element_type() == PRED) {
1978 branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
1979 } else {
1980 branch_index = branch_index_literal.Get<int32>({});
1981 if (branch_index < 0 || branch_index >= conditional->branch_count()) {
1982 branch_index = conditional->branch_count() - 1;
1983 }
1984 }
1985 const auto& branch_computation_arg =
1986 GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
1987
1988 HloEvaluator embedded_evaluator;
1989 embedded_evaluator.set_dynamic_dimension_inference(
1990 dynamic_dimension_inference_);
1991 TF_ASSIGN_OR_RETURN(Literal result,
1992 embedded_evaluator.Evaluate(
1993 *conditional->branch_computation(branch_index),
1994 {&branch_computation_arg}));
1995
1996 evaluated_[conditional] = std::move(result);
1997 return Status::OK();
1998 }
1999
2000 Status HloEvaluator::HandleSelect(HloInstruction* select) {
2001 const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
2002 const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
2003 const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
2004
2005   // If the predicate is scalar, no element-wise selection is needed.
2006 if (ShapeUtil::IsScalar(pred.shape())) {
2007 if (pred.Get<bool>({})) {
2008 evaluated_[select] = on_true.Clone();
2009 } else {
2010 evaluated_[select] = on_false.Clone();
2011 }
2012 return Status::OK();
2013 }
2014
2015 return DefaultAction(select);
2016 }
2017
2018 Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
2019 const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
2020 const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
2021 const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
2022
2023 if (pred.Get<bool>({})) {
2024 evaluated_[tuple_select] = on_true.Clone();
2025 } else {
2026 evaluated_[tuple_select] = on_false.Clone();
2027 }
2028 return Status::OK();
2029 }
2030
2031 Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
2032 HloComputation* cond_comp = while_hlo->while_condition();
2033 HloComputation* body_comp = while_hlo->while_body();
2034   // Initialize the loop-carried value with the input to the While instruction.
2035 auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
2036 bool keep_going = true;
2037 int64 iteration_count = 0;
2038 HloEvaluator cond_evaluator(max_loop_iterations_);
2039 cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
2040 HloEvaluator loop_body_evaluator(max_loop_iterations_);
2041 loop_body_evaluator.set_dynamic_dimension_inference(
2042 dynamic_dimension_inference_);
2043 while (keep_going) {
2044 if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
2045 return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
2046 while_hlo->name(), max_loop_iterations_);
2047 }
2048 TF_ASSIGN_OR_RETURN(auto cond_val,
2049 cond_evaluator.Evaluate(*cond_comp, {&lcv}));
2050 keep_going = cond_val.GetFirstElement<bool>();
2051 if (keep_going) {
2052 TF_ASSIGN_OR_RETURN(auto body_val,
2053 loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
2054 VLOG(3) << "Loop iteration result: " << body_val.ToString();
2055 lcv = std::move(body_val);
2056 cond_evaluator.ResetVisitStates();
2057 loop_body_evaluator.ResetVisitStates();
2058 }
2059 }
2060 evaluated_[while_hlo] = std::move(lcv);
2061 return Status::OK();
2062 }
2063
2064 namespace {
2065 template <typename NativeT>
2066 Literal ExtractLiteralFromIndexPositions(const Literal& from,
2067 absl::Span<int64 const> indices,
2068 bool extract_as_scalar) {
2069 if (extract_as_scalar) {
2070 return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
2071 }
2072   // We use an InlinedVector here because we need to convert it to an
2073 // absl::Span later, and this would not work with std::vector<bool>.
2074 absl::InlinedVector<NativeT, 10> values;
2075 for (int64 index : indices) {
2076 values.push_back(from.Get<NativeT>({index}));
2077 }
2078 return LiteralUtil::CreateR1<NativeT>(values);
2079 }
2080
2081 StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
2082 absl::Span<int64 const> indices,
2083 bool extract_as_scalar = false) {
2084 if (extract_as_scalar) {
2085 CHECK_EQ(indices.size(), 1);
2086 }
2087 PrimitiveType type = from.shape().element_type();
2088 switch (type) {
2089 case PRED: {
2090 return ExtractLiteralFromIndexPositions<bool>(from, indices,
2091 extract_as_scalar);
2092 }
2093 case U8: {
2094 return ExtractLiteralFromIndexPositions<uint8>(from, indices,
2095 extract_as_scalar);
2096 }
2097 case S8: {
2098 return ExtractLiteralFromIndexPositions<int8>(from, indices,
2099 extract_as_scalar);
2100 }
2101 case BF16: {
2102 return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
2103 extract_as_scalar);
2104 }
2105 case F16: {
2106 return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
2107 extract_as_scalar);
2108 }
2109 case U16: {
2110 return ExtractLiteralFromIndexPositions<uint16>(from, indices,
2111 extract_as_scalar);
2112 }
2113 case S16: {
2114 return ExtractLiteralFromIndexPositions<int16>(from, indices,
2115 extract_as_scalar);
2116 }
2117 case F32: {
2118 return ExtractLiteralFromIndexPositions<float>(from, indices,
2119 extract_as_scalar);
2120 }
2121 case U32: {
2122 return ExtractLiteralFromIndexPositions<uint32>(from, indices,
2123 extract_as_scalar);
2124 }
2125 case S32: {
2126 return ExtractLiteralFromIndexPositions<int32>(from, indices,
2127 extract_as_scalar);
2128 }
2129 case F64: {
2130 return ExtractLiteralFromIndexPositions<double>(from, indices,
2131 extract_as_scalar);
2132 }
2133 case C64: {
2134 return ExtractLiteralFromIndexPositions<std::complex<float>>(
2135 from, indices, extract_as_scalar);
2136 }
2137 case U64: {
2138 return ExtractLiteralFromIndexPositions<uint64>(from, indices,
2139 extract_as_scalar);
2140 }
2141 case S64: {
2142 return ExtractLiteralFromIndexPositions<int64>(from, indices,
2143 extract_as_scalar);
2144 }
2145 case C128: {
2146 return ExtractLiteralFromIndexPositions<std::complex<double>>(
2147 from, indices, extract_as_scalar);
2148 }
2149 default:
2150 return InvalidArgument("Unsupported type for Sort: %s",
2151 PrimitiveType_Name(type));
2152 }
2153 }
2154 } // namespace
2155
2156 Status HloEvaluator::HandleSort(HloInstruction* sort) {
2157 TF_RET_CHECK(sort->operand_count() >= 1)
2158 << "Expected at least 1 operand for sort";
2159 for (int64 i = 1; i < sort->operand_count(); ++i) {
2160 TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
2161 sort->operand(i)->shape()))
2162 << "All Sort operands must have the same dimensions";
2163 }
2164
2165 if (VLOG_IS_ON(3)) {
2166 for (int64 i = 0; i < sort->operand_count(); ++i) {
2167 VLOG(3) << "HandleSort operand " << i << " literal: "
2168 << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
2169 }
2170 }
2171 Shape key_shape = sort->operand(0)->shape();
2172 auto rank = key_shape.rank();
2173 std::vector<Literal> result_literals;
2174 result_literals.reserve(sort->operand_count());
2175 for (int64 i = 0; i < sort->operand_count(); ++i) {
2176 result_literals.emplace_back(sort->operand(i)->shape());
2177 }
2178 std::vector<int64> zero_base(rank, 0);
2179 std::vector<int64> increment(rank, 1);
2180 int64 sort_dim = sort->dimensions(0);
2181 int64 sort_dim_elements = key_shape.dimensions(sort_dim);
2182 increment[sort_dim] = sort_dim_elements;
2183 HloEvaluator embedded_evaluator(max_loop_iterations_);
2184 // Iterate through each dimension except 'sort_dim'.
2185 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2186 key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
2187 [&](absl::Span<const int64> indices) -> StatusOr<bool> {
2188 // Extract a slice from each operand literal that corresponds to
2189 // exactly the row in dimension 'sort_dim'.
2190 std::vector<int64> limit_indices(indices.begin(), indices.end());
2191 absl::c_for_each(limit_indices, [](int64& index) { ++index; });
2192 limit_indices[sort_dim] = sort_dim_elements;
2193 std::vector<Literal> literals_to_sort;
2194 literals_to_sort.reserve(sort->operand_count());
2195 for (int64 i = 0; i < sort->operand_count(); ++i) {
2196 TF_ASSIGN_OR_RETURN(auto literal_to_sort,
2197 GetEvaluatedLiteralFor(sort->operand(i))
2198 .Slice(indices, limit_indices)
2199 .Reshape({sort_dim_elements}));
2200 literals_to_sort.push_back(std::move(literal_to_sort));
2201 }
2202 std::vector<int64> indices_to_sort(sort_dim_elements);
2203 std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
2204 Status compare_status = Status::OK();
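        // The comparator computation receives the scalars extracted below as
        // the parameter list (lhs_0, rhs_0, lhs_1, rhs_1, ...), one pair per
        // sort operand, and returns a PRED indicating whether the element at
        // index 'a' should be ordered before the element at index 'b'.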
2205 auto comparator = [sort, &compare_status, &embedded_evaluator,
2206 &literals_to_sort](int64 a, int64 b) {
2207 std::vector<Literal> literals;
2208 literals.reserve(2 * sort->operand_count());
2209 for (int64 i = 0; i < sort->operand_count(); ++i) {
2210 auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
2211 /*extract_as_scalar=*/true);
2212 if (!lhs.ok()) {
2213 compare_status = lhs.status();
2214 return false;
2215 }
2216 literals.push_back(std::move(lhs.ValueOrDie()));
2217 auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
2218 /*extract_as_scalar=*/true);
2219 if (!rhs.ok()) {
2220 compare_status = rhs.status();
2221 return false;
2222 }
2223 literals.push_back(std::move(rhs.ValueOrDie()));
2224 }
2225 std::vector<const Literal*> literal_ptrs;
2226 absl::c_transform(literals, std::back_inserter(literal_ptrs),
2227 [](const Literal& literal) { return &literal; });
2228
2229 auto computed_result =
2230 embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
2231 // Clear visit states so that we can use the evaluator again
2232 // on the same computation.
2233 embedded_evaluator.ResetVisitStates();
2234 if (!computed_result.ok()) {
2235 compare_status = computed_result.status();
2236 return false;
2237 }
2238 return computed_result.ValueOrDie().Get<bool>({});
2239 };
2240 if (Cast<HloSortInstruction>(sort)->is_stable()) {
2241 std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
2242 comparator);
2243 } else {
2244 std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
2245 }
2246 if (!compare_status.ok()) {
2247 return compare_status;
2248 }
2249 std::vector<int64> slice_dimensions(rank, 1);
2250 slice_dimensions[sort_dim] = sort_dim_elements;
2251 std::vector<int64> start_indices(rank, 0);
2252 for (int64 i = 0; i < sort->operand_count(); ++i) {
2253 TF_ASSIGN_OR_RETURN(
2254 Literal sorted_literal,
2255 ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
2256 TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
2257 sorted_literal.Reshape(slice_dimensions));
2258 TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
2259 sorted_literal_reshaped, start_indices, indices,
2260 slice_dimensions));
2261 }
2262 return true;
2263 }));
2264
2265 if (sort->operand_count() == 1) {
2266 evaluated_[sort] = std::move(result_literals[0]);
2267 } else {
2268 std::vector<const Literal*> literal_ptrs;
2269 absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
2270 [](const Literal& literal) { return &literal; });
2271
2272 Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
2273 VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
2274
2275 evaluated_[sort] = std::move(result_tuple);
2276 }
2277 return Status::OK();
2278 }
2279
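// For reference, a computation recognized as a scalar add by IsScalarAdd below
// typically looks like the following (schematic HLO, shown for illustration):
//
//   add (p0: f32[], p1: f32[]) -> f32[] {
//     p0 = f32[] parameter(0)
//     p1 = f32[] parameter(1)
//     ROOT sum = f32[] add(p0, p1)
//   }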
2280 static bool IsScalarAdd(HloComputation* computation) {
2281 HloInstruction* instruction = computation->root_instruction();
2282 if (instruction->opcode() == HloOpcode::kAdd &&
2283 computation->num_parameters() == 2) {
2284 const HloInstruction* lhs = instruction->operand(0);
2285 const HloInstruction* rhs = instruction->operand(1);
2286 return lhs->opcode() == HloOpcode::kParameter &&
2287 ShapeUtil::IsScalar(lhs->shape()) &&
2288 rhs->opcode() == HloOpcode::kParameter &&
2289 ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
2290 }
2291 return false;
2292 }
2293
2294 // Runs a single step of the reduction inner loop: applies the user-provided
2295 // computation to the current accumulator value and the input element (until
2296 // the reduction is complete, the output element also serves as the
2297 // accumulator).
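// For example (illustrative): reducing s32[3] {1, 2, 3} over dimension 0 with
// a scalar-add computation and an init value of 0 performs three such steps,
// taking the single output element from 0 to 1, 3, and finally 6.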
2298 static StatusOr<bool> PerformReductionStep(
2299 bool is_tuple, absl::Span<const int64> input_index,
2300 absl::Span<const int64> output_index,
2301 absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
2302 HloComputation* computation, HloEvaluator* embedded_evaluator) {
2303 int num_args = results.size();
2304
2305 absl::InlinedVector<Literal, 1> arg_values;
2306 arg_values.reserve(num_args);
2307 absl::InlinedVector<Literal, 1> accumulators;
2308 accumulators.reserve(num_args);
2309 for (int64 i = 0; i < num_args; ++i) {
2310 arg_values.emplace_back(
2311 ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
2312 accumulators.emplace_back(
2313 ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
2314
2315 TF_RETURN_IF_ERROR(
2316 arg_values[i].CopyElementFrom(*input_args[i], input_index, {}));
2317 TF_RETURN_IF_ERROR(
2318 accumulators[i].CopyElementFrom(results[i], output_index, {}));
2319 }
2320
2321 // Evaluate computation with specified literal operands.
2322 absl::InlinedVector<Literal*, 2> embedded_operands;
2323 for (Literal& accumulator : accumulators) {
2324 embedded_operands.push_back(&accumulator);
2325 }
2326 for (Literal& local_input : arg_values) {
2327 embedded_operands.push_back(&local_input);
2328 }
2329
2330 TF_ASSIGN_OR_RETURN(
2331 Literal computed_result,
2332 embedded_evaluator->Evaluate(*computation, embedded_operands));
2333
2334 // Clear visit states so that we can use the evaluator again on the same
2335 // computation.
2336 embedded_evaluator->ResetVisitStates();
2337
2338 if (is_tuple) {
2339 std::vector<Literal> computed_results = computed_result.DecomposeTuple();
2340 for (int64 i = 0; i < num_args; ++i) {
2341 TF_RETURN_IF_ERROR(
2342 results[i].CopyElementFrom(computed_results[i], {}, output_index));
2343 }
2344 } else {
2345 TF_RETURN_IF_ERROR(
2346 results[0].CopyElementFrom(computed_result, {}, output_index));
2347 }
2348
2349 return true;
2350 }
2351
2352 static StatusOr<bool> GenerateReduceOutputElement(
2353 bool is_tuple, absl::Span<const int64> output_index,
2354
2355 absl::Span<const Literal* const> init_values,
2356 absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
2357
2358 HloComputation* function, HloEvaluator* embedded_evaluator,
2359
2360 absl::Span<const int64> arg_dim_steps,
2361 absl::Span<const int64> arg_dim_counts,
2362 absl::Span<const int64> result_to_arg_index) {
2363 bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) &&
2364 IsScalarAdd(function) && !is_tuple;
2365
2366 const Shape& arg_shape = input_args[0]->shape();
2367 absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
2368 std::vector<int64> base(arg_dimensions.size());
2369 for (int64 i = 0; i < output_index.size(); ++i) {
2370 base[result_to_arg_index[i]] = output_index[i];
2371 }
2372
2373 for (int64 i = 0; i < results.size(); ++i) {
2374 TF_RETURN_IF_ERROR(
2375 results[i].CopyElementFrom(*init_values[i], {}, output_index));
2376 }
2377
2378 if (use_fast_add) {
2379 double computed_result = *init_values[0]->GetAsDouble({});
2380 auto reduction_step =
2381 [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
2382 double argument = *input_args[0]->GetAsDouble(input_index);
2383 computed_result += argument;
2384 return true;
2385 };
2386 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2387 arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step));
2388 TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result));
2389 return true;
2390 }
2391
2392 // Iterates only over reduced shape, as counts and steps are set to zero
2393 // for all non-reduced dimensions.
2394 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2395 arg_shape, base, arg_dim_counts, arg_dim_steps,
2396 [&](absl::Span<const int64> input_index) {
2397 return PerformReductionStep(is_tuple, input_index, output_index,
2398 input_args, results, function,
2399 embedded_evaluator);
2400 }));
2401 return true;
2402 }
2403
2404 Status HloEvaluator::HandleReduce(HloInstruction* instr) {
2405 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
2406 int64 num_args = reduce->inputs().size();
2407 absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
2408 HloComputation* function = reduce->to_apply();
2409
2410 absl::InlinedVector<const Shape*, 1> operand_shapes;
2411 for (const HloInstruction* operand : reduce->operands()) {
2412 operand_shapes.push_back(&operand->shape());
2413 }
2414 TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
2415 ShapeInference::InferReduceShape(
2416 operand_shapes, dimensions_to_reduce,
2417 /*to_apply=*/function->ComputeProgramShape()));
2418 TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(),
2419 inferred_return_shape))
2420 << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
2421 << " but is inferred to be: "
2422 << ShapeUtil::HumanString(inferred_return_shape);
2423
2424 absl::InlinedVector<const Literal*, 1> input_args(num_args);
2425 absl::InlinedVector<const Literal*, 1> init_values(num_args);
2426 for (int64 i = 0; i < num_args; ++i) {
2427 input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]);
2428 VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString();
2429 init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]);
2430 VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString();
2431 TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape()));
2432 }
2433
2434 // All args and results have the same dimensions, so pick an arbitrary one.
2435 const Shape& arg_shape = input_args[0]->shape();
2436 const Shape& out_shape = inferred_return_shape;
2437 bool is_tuple = out_shape.IsTuple();
2438 const Shape& output_shape = inferred_return_shape.IsTuple()
2439 ? inferred_return_shape.tuple_shapes(0)
2440 : inferred_return_shape;
2441
2442 absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
2443
2444 // All increments are set to 0.
2445 std::vector<int64> arg_dim_steps(arg_dimensions.size());
2446
2447 // All counts are set to 0.
2448 std::vector<int64> arg_dim_counts(arg_dimensions.size());
2449
2450 // Set steps and counts for reduced dimensions.
2451 // This avoids iterating over non-reduced dimensions, as their step
2452   // and count are set to zero.
2453 for (const int64 dim : dimensions_to_reduce) {
2454 arg_dim_steps[dim] = 1;
2455 arg_dim_counts[dim] = arg_dimensions[dim];
2456 }
2457
2458 // Map each dimension in the result to a dimension in arg that isn't
2459 // being reduced.
2460 std::vector<int64> result_to_arg_index;
2461 for (int64 i = 0; i < arg_dimensions.size(); ++i) {
2462 if (arg_dim_steps[i] == 0) {
2463 result_to_arg_index.push_back(i);
2464 }
2465 }
2466
2467 HloEvaluator embedded_evaluator(max_loop_iterations_);
2468 absl::InlinedVector<Literal, 1> results(num_args);
2469 for (int64 i = 0; i < num_args; ++i) {
2470 results[i] = Literal(is_tuple ? out_shape.tuple_shapes(i) : out_shape);
2471 }
2472
2473 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2474 output_shape, [&](absl::Span<const int64> output_index) {
2475 return GenerateReduceOutputElement(
2476 is_tuple, output_index, init_values, input_args,
2477 absl::Span<Literal>(results), function, &embedded_evaluator,
2478 arg_dim_steps, arg_dim_counts, result_to_arg_index);
2479 }));
2480
2481 if (is_tuple) {
2482 Literal tuple_result(inferred_return_shape);
2483 for (int64 i = 0; i < num_args; ++i) {
2484 TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
2485 }
2486 evaluated_[reduce] = std::move(tuple_result);
2487 } else {
2488 CHECK_EQ(results.size(), 1);
2489 evaluated_[reduce] = std::move(results[0]);
2490 }
2491 if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) {
2492 TF_ASSIGN_OR_RETURN(evaluated_[reduce],
2493 evaluated_[reduce].ConvertToShape(reduce->shape()));
2494 }
2495 return Status::OK();
2496 }
2497
2498 Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
2499   // Here we delegate the handling to the typed visitor class, instantiated
2500   // with the type of the first input of ReduceWindow. The support for the
2501   // variadic case inside the typed visitor does not depend on the template
2502   // parameter, so it does not really matter which type is used to instantiate
2503   // it here. We choose not to move the ReduceWindow implementation from the
2504   // typed visitor to here because we need to reuse the IterateThroughWindow
2505   // method, which is defined and only available inside the typed visitor.
2507 if (hlo->shape().IsTuple()) {
2508 return hlo->Visit(
2509 typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get());
2510 } else {
2511 return DefaultAction(hlo);
2512 }
2513 }
2514
2515 Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
2516 if (!custom_call_handler_) {
2517 // No handler is registered; this means custom-calls are not allowed.
2518 return DefaultAction(custom_call);
2519 }
2520
2521 // Evaluate input operands so the handler has access to the operand data.
2522 std::vector<const Literal*> operands;
2523 operands.reserve(custom_call->operand_count());
2524 for (const HloInstruction* operand : custom_call->operands()) {
2525 operands.push_back(&GetEvaluatedLiteralFor(operand));
2526 }
2527
2528   // Synchronously invoke the handler to populate the instruction's output literal.
2529 TF_ASSIGN_OR_RETURN(
2530 auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));
2531
2532 evaluated_[custom_call] = std::move(output);
2533 return Status::OK();
2534 }
2535
2536 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
2537 VLOG(2) << "About to visit HLO: " << hlo->ToString();
2538 return ShapeUtil::ValidateShape(hlo->shape());
2539 }
2540
2541 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
2542 VLOG(2) << "Finished visiting " << hlo->ToString()
2543 << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
2544   // For convenience, the literal may have been produced with a different
2545   // layout; relayout it as indicated by the HLO instruction.
2546 if (!Layout::Equal().MinorToMajorOnly()(
2547 GetEvaluatedLiteralFor(hlo).shape().layout(),
2548 hlo->shape().layout())) {
2549 evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
2550 }
2551 return Status::OK();
2552 }
2553
2554 namespace {
2555 template <typename T>
2556 std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
2557 const Array2D<T>& lhs, const Array2D<T>& rhs,
2558 const std::function<void(
2559 const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n,
2560 int64 k, int32 transpose_lhs, int32 transpose_rhs)>& impl_fn) {
2561 CHECK_EQ(lhs.width(), rhs.height());
2562 int m = lhs.height();
2563 int n = rhs.width();
2564 int k = lhs.width();
2565 auto result = absl::make_unique<Array2D<T>>(m, n);
2566 // Because Eigen is a header-oriented library, make sure that the Eigen code
2567 // is the same as the code used by the CPU backend (otherwise the linker will
2568 // randomly pick *some* definition).
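  // Note: Array2D stores its data in row-major order, while the Eigen-based
  // runtime routine below operates on column-major buffers, so lhs/rhs and
  // m/n are passed swapped; this computes the row-major product through the
  // identity (lhs * rhs)^T == rhs^T * lhs^T.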
2569 impl_fn(
2570 /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
2571 k,
2572 /*transpose_lhs=*/0,
2573 /*transpose_rhs=*/0);
2574 return result;
2575 }
2576 } // namespace
2577
2578 std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
2579 const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
2580 return MatmulArray2DImpl<Eigen::half>(
2581 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
2582 }
2583
2584 std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
2585 const Array2D<float>& lhs, const Array2D<float>& rhs) {
2586 return MatmulArray2DImpl<float>(
2587 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
2588 }
2589
2590 std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
2591 const Array2D<double>& lhs, const Array2D<double>& rhs) {
2592 return MatmulArray2DImpl<double>(
2593 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
2594 }
2595
2596 std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
2597 const Array2D<std::complex<float>>& lhs,
2598 const Array2D<std::complex<float>>& rhs) {
2599 return MatmulArray2DImpl<std::complex<float>>(
2600 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
2601 }
2602
2603 std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
2604 const Array2D<std::complex<double>>& lhs,
2605 const Array2D<std::complex<double>>& rhs) {
2606 return MatmulArray2DImpl<std::complex<double>>(
2607 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
2608 }
2609
2610 std::unique_ptr<Array2D<int32>> HloEvaluator::MatmulArray2D(
2611 const Array2D<int32>& lhs, const Array2D<int32>& rhs) {
2612 return MatmulArray2DImpl<int32>(
2613 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulS32);
2614 }
2615
2616 } // namespace xla
2617