/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
                          LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
  std::function<bool(OperandT, OperandT)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    case ComparisonDirection::kGe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el >= rhs_el;
      };
      break;
    case ComparisonDirection::kGt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el > rhs_el;
      };
      break;
    case ComparisonDirection::kLe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el <= rhs_el;
      };
      break;
    case ComparisonDirection::kLt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el < rhs_el;
      };
      break;
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<OperandT>(multi_index),
                          rhs_literal.Get<OperandT>(multi_index));
      }));

  return std::move(result);
}

template <>
StatusOr<Literal> Compare<complex64>(const Shape& shape,
                                     ComparisonDirection direction,
                                     LiteralSlice lhs_literal,
                                     LiteralSlice rhs_literal) {
  std::function<bool(complex64, complex64)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled direction for conversion to Comparison: "
                 << ComparisonDirectionToString(direction);
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<complex64>(multi_index),
                          rhs_literal.Get<complex64>(multi_index));
      }));

  return std::move(result);
}

template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape,
                                      ComparisonDirection direction,
                                      LiteralSlice lhs_literal,
                                      LiteralSlice rhs_literal) {
  std::function<bool(complex128, complex128)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](complex128 lhs_el, complex128 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](complex128 lhs_el, complex128 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled direction for conversion to Comparison: "
                 << ComparisonDirectionToString(direction);
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<complex128>(multi_index),
                          rhs_literal.Get<complex128>(multi_index));
      }));

  return std::move(result);
}

}  // namespace

// Note that a type unsupported by the typed visitor is not necessarily
// unsupported by the parent HloEvaluator, whose handlers are type-agnostic.
// For example, HandleGetTupleElement in the parent evaluator accepts the
// Tuple primitive type, whereas HloEvaluatorTypedVisitor cannot.
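//
// As a minimal illustrative sketch (hypothetical variable names, not code
// from this file), such a computation can still be evaluated end to end:
//
//   HloEvaluator evaluator;
//   StatusOr<Literal> result = evaluator.Evaluate(*computation, {&arg0});
//
// even when the root is a tuple, because HandleTuple below never dispatches
// to typed_visitors_[TUPLE].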
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
    : max_loop_iterations_(max_loop_iterations) {
  typed_visitors_[PRED] =
      absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
  typed_visitors_[U8] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
  typed_visitors_[U16] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
  typed_visitors_[U32] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
  typed_visitors_[U64] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
  typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
  typed_visitors_[S16] =
      absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
  typed_visitors_[S32] =
      absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
  typed_visitors_[S64] =
      absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
  typed_visitors_[F16] =
      absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
  typed_visitors_[F32] =
      absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
  typed_visitors_[F64] =
      absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
  typed_visitors_[C64] =
      absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
  typed_visitors_[C128] =
      absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);

  // Most of the evaluator computations we use don't support BF16 (e.g.,
  // std::ceil, std::tanh). To make the evaluator work with BF16, we set all
  // elementwise computations to be done in F32 and do BF16<->F32 conversions
  // around the input and the output of the computations.
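  //
  // For instance (a hedged sketch of the conversion described above, not code
  // from this file), evaluating tanh on a BF16 element x conceptually
  // computes:
  //
  //   bfloat16 y = static_cast<bfloat16>(std::tanh(static_cast<float>(x)));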
  typed_visitors_[BF16] =
      absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);

  typed_visitors_[TUPLE] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
      });
  typed_visitors_[OPAQUE] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
      });
  typed_visitors_[TOKEN] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
      });
}

StatusOr<Literal> HloEvaluator::Evaluate(
    const HloComputation& computation,
    absl::Span<const Literal* const> arg_literals) {
  CHECK(computation.parent() != nullptr);
  XLA_VLOG_LINES(
      2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());

  if (arg_literals.size() != computation.num_parameters()) {
    return InvalidArgument(
        "Expected %d argument%s, but got %d.", computation.num_parameters(),
        computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
  }
  for (int64 i = 0; i < arg_literals.size(); ++i) {
    const auto& computation_shape =
        computation.parameter_instruction(i)->shape();
    const auto& arg_shape = arg_literals[i]->shape();
    if (!ShapeUtil::Equal(computation_shape, arg_shape)) {
      return InvalidArgument(
          "Shape mismatch at parameter %d. Computation expected %s, but arg "
          "was %s.",
          i, ShapeUtil::HumanStringWithLayout(computation_shape),
          ShapeUtil::HumanString(arg_shape));
    }
  }

  evaluated_.clear();
  arg_literals_.clear();
  for (const auto& literal_ptr : arg_literals) {
    arg_literals_.push_back(&*literal_ptr);
  }

  // Re-seed RNG, either from the configuration's seed or a monotonic
  // per-evaluator seed (which prevents two evaluators from returning the same
  // random sequence).
  if (computation.parent()->config().seed()) {
    seed_ = computation.parent()->config().seed();
  } else {
    // Start global_seed at a (true) random value.
    static std::atomic<uint64> global_seed{std::random_device()()};
    seed_ = global_seed.fetch_add(1);
  }
  engine_.seed(seed_);

  TF_RETURN_IF_ERROR(computation.Accept(this));
  return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
}

StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kParameter) {
    return tensorflow::errors::FailedPrecondition(
        "Cannot evaluate a parameter.");
  }
  if (!hlo_query::AllOperandsAreConstants(*instruction)) {
    return tensorflow::errors::FailedPrecondition(
        "Not all operands are constants.");
  }

  arg_literals_.clear();
  evaluated_.clear();

  TF_RETURN_IF_ERROR(Preprocess(instruction));
  TF_RETURN_IF_ERROR(instruction->Visit(this));
  TF_RETURN_IF_ERROR(Postprocess(instruction));
  return GetEvaluatedLiteralFor(instruction).Clone();
}

bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
  CHECK(result != nullptr);
  auto result_or = Evaluate(instruction);
  if (!result_or.ok()) {
    VLOG(1) << "TryEvaluate failed: " << result_or.status();
    return false;
  }

  *result = result_or.ConsumeValueOrDie();
  return true;
}

StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
    const HloInstruction* instruction,
    const std::unordered_map<const HloInstruction*, const Literal*>&
        substitutions) {
  std::vector<std::unique_ptr<HloInstruction>> owned_operands;
  for (const HloInstruction* operand : instruction->operands()) {
    auto it = substitutions.find(operand);
    if (it == substitutions.end()) {
      owned_operands.push_back(operand->Clone());
    } else {
      owned_operands.push_back(
          HloInstruction::CreateConstant(it->second->Clone()));
    }
  }

  std::vector<HloInstruction*> operands;
  operands.reserve(owned_operands.size());
  for (auto& operand : owned_operands) {
    operands.push_back(operand.get());
  }

  std::unique_ptr<HloInstruction> cloned_instruction =
      instruction->CloneWithNewOperands(instruction->shape(), operands);
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
    HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
                                   rhs_instr.get());
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
    HloOpcode opcode, const Literal& operand) {
  std::unique_ptr<HloInstruction> operand_instr =
      HloInstruction::CreateConstant(operand.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateDotOp(
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, const Literal& lhs,
    const Literal& rhs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());

  TF_ASSIGN_OR_RETURN(
      Shape dot_shape,
      ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
                                dim_numbers, precision_config);
  return Evaluate(cloned_instruction.get());
}

Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
  const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
  Literal result(bitcast->shape());
  TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
  memcpy(result.untyped_data(), operand_literal.untyped_data(),
         operand_literal.size_bytes());
  evaluated_[bitcast] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleGetDimensionSize(
    HloInstruction* get_dimension_size) {
  HloInstruction* operand = get_dimension_size->mutable_operand(0);
  int64 dim = get_dimension_size->dimension();
  if (dynamic_dimension_inference_ == nullptr) {
    return InvalidArgument(
        "Evaluator cannot evaluate get_dimension_size without "
        "set_dynamic_dimension_inference.");
  }
  HloInstruction* dynamic_size =
      dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
  if (dynamic_size != nullptr) {
    evaluated_[get_dimension_size] =
        GetEvaluatedLiteralFor(dynamic_size).Clone();
    return Status::OK();
  }

  const Shape& shape = get_dimension_size->operand(0)->shape();
  Literal output(ShapeUtil::MakeShape(U32, {}));
  output.PopulateWithValue(
      static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
  evaluated_[get_dimension_size] = std::move(output);
  return Status::OK();
}

Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
  // Nothing to do other than sanity checks. Parameters' values are stored in
  // arg_literals_.
  CHECK_LT(parameter->parameter_number(), arg_literals_.size());

#ifndef NDEBUG
  const Literal* input_literal = arg_literals_[parameter->parameter_number()];
  VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
  DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
      << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
      << ", but input literal shape is: "
      << ShapeUtil::HumanString(input_literal->shape());
#endif

  return Status::OK();
}

Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }

Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
  TF_ASSIGN_OR_RETURN(
      evaluated_[reshape],
      GetEvaluatedLiteralFor(reshape->operand(0))
          .Reshape(AsInt64Slice(reshape->shape().dimensions())));
  return Status::OK();
}

Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
  evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
                              .Transpose(transpose->dimensions());
  return Status::OK();
}

Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  // The concatenate dimension of the result is the sum of the concatenate
  // dimensions of all operands taking part in the operation.
  const Shape& reference_shape = operands[0]->shape();
  CHECK(reference_shape.IsArray());
  const int64 rank = reference_shape.rank();
  const int64 concat_dim = concatenate->dimensions()[0];
  CHECK_GE(concat_dim, 0);
  CHECK_LT(concat_dim, rank);

  DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
                                    reference_shape.dimensions().end());

  for (int64 i = 1; i < operands.size(); ++i) {
    const Shape& operand_shape = operands[i]->shape();
    CHECK(operand_shape.IsArray());
    // Accumulate the concat dimension from all tensors taking part in the
    // operation.
    concat_dimensions[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  auto result_literal = LiteralUtil::CreateFromDimensions(
      reference_shape.element_type(), concat_dimensions);
  DimensionVector source_indices(rank, 0);
  DimensionVector dest_indices(concat_dimensions.size(), 0);

  for (auto operand : operands) {
    const Shape& operand_shape = operand->shape();
    TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
        GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
        AsInt64Slice(operand_shape.dimensions())));
    dest_indices[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  evaluated_[concatenate] = std::move(result_literal);
  return Status::OK();
}

Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
  auto operand = is_finite->operand(0);
  auto elem_ty = operand->shape().element_type();
  switch (elem_ty) {
    case PRED:
    case TUPLE:
    case OPAQUE:
    case TOKEN:
    case S8:
    case S16:
    case S32:
    case S64:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    // Explicitly enumerate all types in this switch so that when we add a new
    // type, we'll get a compile error here.
    case PRIMITIVE_TYPE_INVALID:
    case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
    case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
      return InvalidArgument(
          "expected element type in shape to be floating point, but "
          "got: %s",
          PrimitiveType_Name(elem_ty));

    case F16: {
      auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
          is_finite,
          [](Eigen::half elem_operand) {
            return std::isfinite(static_cast<float>(elem_operand));
          },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case BF16: {
      auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
          is_finite,
          [](bfloat16 elem_operand) {
            return std::isfinite(static_cast<float>(elem_operand));
          },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case F32: {
      auto result_or = ElementWiseUnaryOpImpl<bool, float>(
          is_finite,
          [](float elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case F64: {
      auto result_or = ElementWiseUnaryOpImpl<bool, double>(
          is_finite,
          [](double elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
  }

  return Status::OK();
}

Status HloEvaluator::HandleReal(HloInstruction* real) {
  auto operand = real->operand(0);
  switch (operand->shape().element_type()) {
    case BF16: {
      auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
          real, [](bfloat16 elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case C64: {
      auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
          real, [](complex64 elem_operand) { return std::real(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case C128: {
      auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
          real, [](complex128 elem_operand) { return std::real(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F16: {
      auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
          real, [](Eigen::half elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F32: {
      auto result_or = ElementWiseUnaryOpImpl<float, float>(
          real, [](float elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F64: {
      auto result_or = ElementWiseUnaryOpImpl<double, double>(
          real, [](double elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    default:
      LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(operand->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleImag(HloInstruction* imag) {
  auto operand = imag->operand(0);
  switch (operand->shape().element_type()) {
    case C64: {
      auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
          imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
          GetEvaluatedLiteralFor(imag->operand(0)));

      TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
      break;
    }
    case C128: {
      auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
          imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
          GetEvaluatedLiteralFor(imag->operand(0)));

      TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
      break;
    }
    default:
      LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(operand->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleComplex(HloInstruction* complex) {
  const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
  const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
  TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));

  Literal result(complex->shape());
  switch (complex->shape().element_type()) {
    case C64: {
      TF_RETURN_IF_ERROR(
          result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
            return std::complex<float>(real.Get<float>(multi_index),
                                       imag.Get<float>(multi_index));
          }));
      break;
    }
    case C128: {
      TF_RETURN_IF_ERROR(
          result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
            return std::complex<double>(real.Get<double>(multi_index),
                                        imag.Get<double>(multi_index));
          }));
      break;
    }
    default:
      LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(complex->shape().element_type());
  }

  evaluated_[complex] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleCompare(HloInstruction* compare) {
  ComparisonDirection direction = compare->comparison_direction();
  auto lhs = compare->operand(0);
  auto rhs = compare->operand(1);
  DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
         ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));

  TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());

  const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
  const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);

  // Note here we switch on the operand's type.
  switch (lhs->shape().element_type()) {
    case PRED: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case U8: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint8>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case U16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint16>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case U32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint32>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case U64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint64>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case S8: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case S16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int16>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case S32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int32>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case S64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int64>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case F16: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case BF16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<bfloat16>(compare->shape(), direction,
                                            lhs_literal, rhs_literal));
    } break;
    case F32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<float>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case F64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<double>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case C64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<complex64>(compare->shape(), direction,
                                             lhs_literal, rhs_literal));
    } break;
    case C128: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<complex128>(compare->shape(), direction,
                                              lhs_literal, rhs_literal));
    } break;
    default:
      LOG(FATAL) << "HandleCompare: unknown primitive type: "
                 << PrimitiveType_Name(lhs->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
  std::vector<const Literal*> operand_literals;
  for (auto operand : tuple->operands()) {
    operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
  }

  evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
  return Status::OK();
}

// Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
// dimensions while keeping the rest of the output dimensions clamped to 0.
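// For example (hypothetical shapes): for an output of shape [N, p, q] with
// offset_dims = {1, 2}, this yields index_base = [0, 0, 0], index_count =
// [N, 1, 1], and unit increments, so only the batch dimension is traversed.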
ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
    const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
  int64 output_rank = output_shape.dimensions_size();
  std::vector<int64> index_base(output_rank, 0);
  std::vector<int64> index_count;
  index_count.reserve(output_rank);
  for (int64 i = 0; i < output_rank; i++) {
    bool is_output_batch_dim =
        !absl::c_binary_search(dim_numbers.offset_dims(), i);
    index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
  }

  return {std::move(index_base), std::move(index_count),
          std::vector<int64>(output_rank, 1)};
}

// Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
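// Continuing the hypothetical example: for an output of shape [N, p, q] with
// offset_dims = {1, 2}, collapsed_slice_dims = {}, and slice_sizes = [p, q],
// this yields index_base = [0, 0, 0] and index_count = [1, p, q].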
ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
    int64 output_rank, absl::Span<const int64> slice_sizes,
    const GatherDimensionNumbers& dim_numbers) {
  std::vector<int64> index_base(output_rank, 0);
  std::vector<int64> index_count(output_rank, 1);
  int64 slice_sizes_idx = 0;
  for (int64 i = 0; i < output_rank; i++) {
    bool is_output_window_dim =
        absl::c_binary_search(dim_numbers.offset_dims(), i);
    if (is_output_window_dim) {
      while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
                                   slice_sizes_idx)) {
        slice_sizes_idx++;
      }
      index_count[i] = slice_sizes[slice_sizes_idx++];
    }
  }

  return {std::move(index_base), std::move(index_count),
          std::vector<int64>(output_rank, 1)};
}

// This functor computes the contribution of start_indices to an input index
// corresponding to an output index. That is, given an output index I, it picks
// out the batch indices in I and uses them to look up a starting index, G,
// from the start indices tensor, and expands G into the input space according
// to start_index_map.
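//
// Illustrative example (hypothetical dimension numbers): with an operand of
// shape [A, B], start_indices of shape [N, 2], index_vector_dim = 1,
// start_index_map = {0, 1}, and offset_dims = {1, 2}, an output index
// I = [n, i, j] has batch indices [n], the fetched index vector is
// G = [start_indices[n, 0], start_indices[n, 1]], and operator() returns
// [G[0], G[1]] as the start_indices contribution to the input index.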
class OutputBatchIndexToInputIndex {
 public:
  // The constructor does some setup work that is amortized across all
  // iterations.
  explicit OutputBatchIndexToInputIndex(
      const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
      const Shape& output_shape, const Literal* start_indices)
      : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
    for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
      output_dim_is_batch_dims_.push_back(
          !absl::c_binary_search(dim_numbers_.offset_dims(), i));
    }

    for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
      int64 index_of_input_dim_in_index_vector =
          std::distance(dim_numbers_.start_index_map().begin(),
                        absl::c_find(dim_numbers_.start_index_map(), i));
      if (index_of_input_dim_in_index_vector ==
          dim_numbers_.start_index_map_size()) {
        input_dim_value_to_index_vector_.push_back(-1);
      } else {
        input_dim_value_to_index_vector_.push_back(
            index_of_input_dim_in_index_vector);
      }
    }

    index_vector_index_.resize(start_indices_.shape().dimensions_size());
    input_index_.resize(input_shape.dimensions_size());
    int64 index_vector_size =
        start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
    index_vector_.resize(index_vector_size);
  }

  // Returns the contribution of start_indices to the input index corresponding
  // to output_index. See gather_inner_loop_body.
  //
  // This is conceptually a stateless transformation from output_index to the
  // gather input index, but:
  //
  //  - Instead of allocating memory to represent the gather input index on
  //    every invocation we reuse the same storage for the result
  //    (input_index_), mutating it in place.
  //  - Instead of allocating buffers for temporary values like
  //    index_vector_index_ and index_vector on every invocation, we reuse the
  //    same storage for all invocations.
  //
  // This returns a Span into memory owned by the class.
  StatusOr<absl::Span<const int64>> operator()(
      absl::Span<const int64> output_index) {
    PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
    TF_RETURN_IF_ERROR(FetchIndexVector());
    PropagateIndexVectorToInputIndex();
    return absl::Span<const int64>(input_index_);
  }

 private:
  // Propagates the batch dimensions from the output index into
  // index_vector_index_ by mutating index_vector_index_ in place. Does not
  // update the dim_numbers.index_vector_dim() dimension -- that's the
  // dimension we iterate over in FetchIndexVector.
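  //
  // For instance (illustrative, hypothetical values): with index_vector_dim()
  // == 1 and output batch indices [b0, b1], this writes index_vector_index_ =
  // [b0, *, b1], where the skipped middle slot is filled in later by
  // FetchIndexVector.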
  void PropagateOutputIndexGatherDimsToIndexVectorIndex(
      absl::Span<const int64> output_index) {
    int64 index_vector_index_i = 0;
    for (int64 i = 0, e = output_index.size(); i < e; i++) {
      if (!output_dim_is_batch_dims_[i]) {
        continue;
      }

      if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
        index_vector_index_i++;
      }

      index_vector_index_[index_vector_index_i++] = output_index[i];
    }
  }

  // Populates index_vector_ by iterating over start_indices_ according to
  // index_vector_index_.
  Status FetchIndexVector() {
    int64 index_vector_dim = dim_numbers_.index_vector_dim();
    for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
      index_vector_index_[index_vector_dim] = i;
      TF_ASSIGN_OR_RETURN(index_vector_[i],
                          start_indices_.GetIntegralAsS64(index_vector_index_));
    }
    return Status::OK();
  }

  // Populates input_index_.
  void PropagateIndexVectorToInputIndex() {
    for (int64 i = 0, e = input_index_.size(); i < e; i++) {
      if (input_dim_value_to_index_vector_[i] != -1) {
        input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
      }

      // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
      // remains 0, as set by the constructor.
    }
  }

  // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
  // the input index from the index vector. See
  // PropagateIndexVectorToInputIndex.
  std::vector<int64> input_dim_value_to_index_vector_;

  // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
  // dimension.
  std::vector<bool> output_dim_is_batch_dims_;

  // The buffer into which we construct an index into start_indices_ to fetch
  // the index vector.
  std::vector<int64> index_vector_index_;

  // The index vector fetched from start_indices_.
  std::vector<int64> index_vector_;

  // The result computed by this functor. operator() returns a Span into
  // this vector.
  std::vector<int64> input_index_;

  const GatherDimensionNumbers& dim_numbers_;
  const Literal& start_indices_;
};

// This functor computes the contribution of the offset indices in an output
// index to an input index. That is, given an output index I, it picks out the
// output offset indices in I and expands them into an index into the input
// shape.
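//
// Continuing the illustrative example above: for output index I = [n, i, j]
// with offset_dims = {1, 2} and no collapsed slice dims, operator() returns
// [i, j]; the final gather input index is the elementwise sum of this offset
// contribution and the (clamped) start-index contribution.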
class OutputOffsetIndexToInputIndex {
 public:
  // The constructor does some setup work that is amortized across all
  // iterations.
  explicit OutputOffsetIndexToInputIndex(
      const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
      const Shape& output_shape) {
    std::vector<int64> window_index_to_output_index;
    int64 output_index_count = 0;
    for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
      if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
        window_index_to_output_index.push_back(output_index_count++);
      } else {
        output_index_count++;
      }
    }

    int64 window_dim_count = 0;
    for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
      if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
        input_dim_value_to_output_index_.push_back(-1);
      } else {
        input_dim_value_to_output_index_.push_back(
            window_index_to_output_index[window_dim_count++]);
      }
    }

    input_index_.resize(input_shape.dimensions_size());
  }

  // Returns the contribution of the window indices to the input index
  // corresponding to output_index. See gather_inner_loop_body.
  //
  // This is conceptually a stateless transformation from output_index to the
  // window input index, but instead of allocating memory to represent the
  // gather input index on every invocation we reuse the same storage for the
  // result (input_index_), mutating it in place.
  //
  // This returns a Span into memory owned by the class.
  StatusOr<absl::Span<const int64>> operator()(
      absl::Span<const int64> output_index) {
    PropagateOutputIndexWindowDimsToInputIndex(output_index);
    return absl::Span<const int64>(input_index_);
  }

  // Returns for a given 'input_dim' the corresponding output dimension index,
  // or -1 if 'input_dim' is an elided window dimension.
  int64 input_dim_value_to_output_index(int64 input_dim) {
    return input_dim_value_to_output_index_[input_dim];
  }

 private:
  // Propagates window dimensions from the output index to input_index_ by
  // mutating input_index_ in place.
  void PropagateOutputIndexWindowDimsToInputIndex(
      absl::Span<const int64> output_index) {
    for (int64 i = 0, e = input_index_.size(); i < e; i++) {
      if (input_dim_value_to_output_index_[i] != -1) {
        input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
      }

      // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
      // remains 0, as set by the constructor.
    }
  }

  // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
  // the input index from the output index. See
  // PropagateOutputIndexWindowDimsToInputIndex.
  std::vector<int64> input_dim_value_to_output_index_;

  // The result computed by this functor. operator() returns a Span into
  // this vector.
  std::vector<int64> input_index_;
};

// Reshapes the gather indices input to have a trailing degenerate `1`
// dimension if necessary. Hands over the ownership of the newly created
// literal (if there is one) to `reshaped_start_indices`.
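// For example (illustrative): start_indices of shape [N] with
// index_vector_dim = 1 is reshaped to [N, 1], while start_indices of shape
// [N, 2] with index_vector_dim = 1 is passed through unchanged.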
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
    int64 index_vector_dim, const Literal& start_indices,
    Literal* reshaped_start_indices) {
  if (start_indices.shape().dimensions_size() != index_vector_dim) {
    return std::cref(start_indices);
  }

  std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
                               start_indices.shape().dimensions().end());
  new_shape.push_back(1);
  TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
                      start_indices.Reshape(new_shape));
  return std::cref(*reshaped_start_indices);
}

Status HloEvaluator::HandleGather(HloInstruction* gather) {
  Literal result = Literal::CreateFromShape(gather->shape());
  const Shape& shape = gather->shape();
  const GatherDimensionNumbers& dim_numbers =
      gather->gather_dimension_numbers();
  const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
  Literal reshaped_start_indices;
  TF_ASSIGN_OR_RETURN(
      const Literal& start_indices,
      ReshapedGatherIndices(dim_numbers.index_vector_dim(),
                            GetEvaluatedLiteralFor(gather->operand(1)),
                            &reshaped_start_indices));

  // We iterate over the gather dimensions in the output shape in an outer loop
  // nest, and iterate over the window dimensions in the output shape in an
  // inner loop nest.

  ShapeUtil::IndexIterationSpace start_indices_iteration_space =
      IterationSpaceForOutputBatchIndices(shape, dim_numbers);
  ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
      IterationSpaceForOutputOffsetIndices(
          shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);

  // Scratch buffers that hold an index in the output shape and the
  // corresponding index in the input shape.
  std::vector<int64> input_index(operand.shape().dimensions_size());
  std::vector<int64> output_index(gather->shape().dimensions_size());
  std::vector<int64> input_index_clamped(operand.shape().dimensions_size());

  OutputBatchIndexToInputIndex output_batch_index_to_input_index(
      &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
      /*output_shape=*/shape, &start_indices);
  OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
      gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
      /*output_shape=*/shape);

  const Shape& operand_shape = operand.shape();

  auto gather_inner_loop_body =
      [&](absl::Span<const int64> output_window_index,
          absl::Span<const int64> input_gather_index,
          absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(
        absl::Span<const int64> input_window_index,
        output_offset_index_to_input_index(output_window_index));
    for (int i = 0, e = output_index.size(); i < e; i++) {
      output_index[i] = output_gather_index[i] + output_window_index[i];
      DCHECK_LT(output_index[i], shape.dimensions(i));
    }
    for (int i = 0, e = input_gather_index.size(); i < e; i++) {
      int64 output_dim =
          output_offset_index_to_input_index.input_dim_value_to_output_index(i);
      // If 'output_dim' is -1, it means 'i' is an elided window dim. This
      // means we set the iteration index to 0, so for the purpose of the
      // following calculations we can consider the output dimension size to
      // be 1.
      int64 output_dim_size =
          output_dim == -1 ? 1 : shape.dimensions(output_dim);
      // Clamp the gather index so that the gather region fits in the operand.
      //   input_index_clamped[i] = clamp(input_gather_index[i], 0,
      //                                  operand_shape.dimensions(i) -
      //                                  output_dim_size);
      input_index_clamped[i] =
          std::min(operand_shape.dimensions(i) - output_dim_size,
                   std::max(0LL, input_gather_index[i]));
    }
    for (int i = 0, e = input_index.size(); i < e; i++) {
      input_index[i] = input_index_clamped[i] + input_window_index[i];
      DCHECK_GE(input_index[i], 0);
      DCHECK_LT(input_index[i], operand_shape.dimensions(i));
    }
    TF_RETURN_IF_ERROR(
        result.CopyElementFrom(operand, input_index, output_index));
    return true;
  };

  auto gather_outer_loop_body =
      [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(
        absl::Span<const int64> input_gather_index,
        output_batch_index_to_input_index(output_gather_index));
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        shape, offset_indices_iteration_space,
        std::bind(gather_inner_loop_body, std::placeholders::_1,
                  input_gather_index, output_gather_index)));
    return true;
  };

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      shape, start_indices_iteration_space, gather_outer_loop_body));
  evaluated_[gather] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
  const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));

  TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
      << "broadcast dimensions is of size: " << broadcast->dimensions().size()
      << " and rank of operand_to_broadcast is: " << operand.shape().rank();
  // Checks that operand's dimensions are the same as the broadcast's
  // dimensions along the dimensions to be broadcasted.
  for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
    auto operand_dim_size = operand.shape().dimensions(i);
    auto broadcast_dim_size =
        broadcast->shape().dimensions(broadcast->dimensions(i));
    TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
        "Operand dimension %d is broadcast to output dimension %d, but the "
        "sizes of these two dims do not match (%d vs %d): %s",
        i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
        broadcast->ToString());
  }

  TF_ASSIGN_OR_RETURN(
      evaluated_[broadcast],
      operand.Broadcast(broadcast->shape(), broadcast->dimensions()));

  return Status::OK();
}

Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
  evaluated_[after_all] = LiteralUtil::CreateToken();
  return Status::OK();
}

Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
  // AddDependency just forwards its zeroth operand.
  evaluated_[add_dependency] =
      GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
  return Status::OK();
}

Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  const auto result_shape = get_tuple_element->shape();
  const int64 index = get_tuple_element->tuple_index();

  auto operand = get_tuple_element->operand(0);
  TF_ASSIGN_OR_RETURN(
      auto inferred_return_shape,
      ShapeInference::InferGetTupleElementShape(operand->shape(), index));
  TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
      << "return shape set to: " << ShapeUtil::HumanString(result_shape)
      << " but is inferred to be: "
      << ShapeUtil::HumanString(inferred_return_shape);

  const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);

  evaluated_[get_tuple_element] =
      Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
  return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
                                                /*dest_shape_index=*/{},
                                                /*src_shape_index=*/{index});
}

Status HloEvaluator::HandleCopy(HloInstruction* copy) {
  TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
  evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
  return Status::OK();
}

Status HloEvaluator::HandleCall(HloInstruction* call) {
  auto* computation = call->to_apply();
  auto operands = call->operands();

  std::vector<const Literal*> arg_literals;
  arg_literals.reserve(operands.size());
  for (auto operand : operands) {
    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
    arg_literals.push_back(&arg_literal);
  }

  HloEvaluator embedded_evaluator;
  embedded_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result,
                      embedded_evaluator.Evaluate(*computation, arg_literals));

  evaluated_[call] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
  HloModuleConfig config;
  // Attach cloned computation to an empty HLO module so the existing ones are
  // not modified.
  HloModule empty_hlo_module("EmptyModuleForFusion", config);
  HloCloneContext context(&empty_hlo_module);
  auto cloned_fused_computation =
      fusion->fused_instructions_computation()->Clone(
          /*suffix=*/"clone_with_layout", &context);
  for (auto* instruction : cloned_fused_computation->instructions()) {
    if (!LayoutUtil::HasLayout(instruction->shape())) {
      LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
    }
  }
  auto readded_computation =
      empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));

  auto operands = fusion->operands();
  std::vector<const Literal*> arg_literals;
  arg_literals.reserve(operands.size());
  for (auto operand : operands) {
    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
    arg_literals.push_back(&arg_literal);
  }

  HloEvaluator embedded_evaluator;
  embedded_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
                                          *readded_computation, arg_literals));

  evaluated_[fusion] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
  const auto& branch_index_literal =
      GetEvaluatedLiteralFor(conditional->operand(0));
  int branch_index;
  if (conditional->operand(0)->shape().element_type() == PRED) {
    branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
  } else {
    branch_index = branch_index_literal.Get<int32>({});
    if (branch_index < 0 || branch_index >= conditional->branch_count()) {
      branch_index = conditional->branch_count() - 1;
    }
  }
  const auto& branch_computation_arg =
      GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));

  HloEvaluator embedded_evaluator;
  embedded_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result,
                      embedded_evaluator.Evaluate(
                          *conditional->branch_computation(branch_index),
                          {&branch_computation_arg}));

  evaluated_[conditional] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleSelect(HloInstruction* select) {
  const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
  const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
  const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));

  // If predicate is of scalar type, no element-wise selection would be needed.
  if (ShapeUtil::IsScalar(pred.shape())) {
    if (pred.Get<bool>({})) {
      evaluated_[select] = on_true.Clone();
    } else {
      evaluated_[select] = on_false.Clone();
    }
    return Status::OK();
  }

  return DefaultAction(select);
}

Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
  const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
  const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
  const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));

  if (pred.Get<bool>({})) {
    evaluated_[tuple_select] = on_true.Clone();
  } else {
    evaluated_[tuple_select] = on_false.Clone();
  }
  return Status::OK();
}

Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
  HloComputation* cond_comp = while_hlo->while_condition();
  HloComputation* body_comp = while_hlo->while_body();
  // Initialize the loop-carried value with the input to the While instruction.
  auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
  bool keep_going = true;
  int64 iteration_count = 0;
  HloEvaluator cond_evaluator(max_loop_iterations_);
  cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
  HloEvaluator loop_body_evaluator(max_loop_iterations_);
  loop_body_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  while (keep_going) {
    if (max_loop_iterations_ >= 0 &&
        iteration_count++ > max_loop_iterations_) {
      return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
                             while_hlo->name(), max_loop_iterations_);
    }
    TF_ASSIGN_OR_RETURN(auto cond_val,
                        cond_evaluator.Evaluate(*cond_comp, {&lcv}));
    keep_going = cond_val.GetFirstElement<bool>();
    if (keep_going) {
      TF_ASSIGN_OR_RETURN(auto body_val,
                          loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
      VLOG(3) << "Loop iteration result: " << body_val.ToString();
      lcv = std::move(body_val);
      cond_evaluator.ResetVisitStates();
      loop_body_evaluator.ResetVisitStates();
    }
  }
  evaluated_[while_hlo] = std::move(lcv);
  return Status::OK();
}

namespace {
template <typename NativeT>
Literal ExtractLiteralFromIndexPositions(const Literal& from,
                                         absl::Span<int64 const> indices,
                                         bool extract_as_scalar) {
  if (extract_as_scalar) {
    return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
  }
  // We use an InlinedVector here because we need to convert it to an
  // absl::Span later, and this would not work with std::vector<bool>.
1362 absl::InlinedVector<NativeT, 10> values;
1363 for (int64 index : indices) {
1364 values.push_back(from.Get<NativeT>({index}));
1365 }
1366 return LiteralUtil::CreateR1<NativeT>(values);
1367 }
1368
StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
                                            absl::Span<int64 const> indices,
                                            bool extract_as_scalar = false) {
  if (extract_as_scalar) {
    CHECK_EQ(indices.size(), 1);
  }
  PrimitiveType type = from.shape().element_type();
  switch (type) {
    case PRED: {
      return ExtractLiteralFromIndexPositions<bool>(from, indices,
                                                    extract_as_scalar);
    }
    case U8: {
      return ExtractLiteralFromIndexPositions<uint8>(from, indices,
                                                     extract_as_scalar);
    }
    case S8: {
      return ExtractLiteralFromIndexPositions<int8>(from, indices,
                                                    extract_as_scalar);
    }
    case BF16: {
      return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
                                                        extract_as_scalar);
    }
    case F16: {
      return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
                                                           extract_as_scalar);
    }
    case U16: {
      return ExtractLiteralFromIndexPositions<uint16>(from, indices,
                                                      extract_as_scalar);
    }
    case S16: {
      return ExtractLiteralFromIndexPositions<int16>(from, indices,
                                                     extract_as_scalar);
    }
    case F32: {
      return ExtractLiteralFromIndexPositions<float>(from, indices,
                                                     extract_as_scalar);
    }
    case U32: {
      return ExtractLiteralFromIndexPositions<uint32>(from, indices,
                                                      extract_as_scalar);
    }
    case S32: {
      return ExtractLiteralFromIndexPositions<int32>(from, indices,
                                                     extract_as_scalar);
    }
    case F64: {
      return ExtractLiteralFromIndexPositions<double>(from, indices,
                                                      extract_as_scalar);
    }
    case U64: {
      return ExtractLiteralFromIndexPositions<uint64>(from, indices,
                                                      extract_as_scalar);
    }
    case S64: {
      return ExtractLiteralFromIndexPositions<int64>(from, indices,
                                                     extract_as_scalar);
    }
    default:
      return InvalidArgument("Unsupported type for Sort: %s",
                             PrimitiveType_Name(type));
  }
}
}  // namespace

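// Sorts every operand along 'sort_dim', using the 'to_apply' computation as
// the "less than" comparator. Each one-dimensional slice along the sort
// dimension is extracted, a permutation of its indices is sorted with the
// comparator, and the permuted values are copied back into the result
// literal(s).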
Status HloEvaluator::HandleSort(HloInstruction* sort) {
  TF_RET_CHECK(sort->operand_count() >= 1)
      << "Expected at least 1 operand for sort";
  for (int64 i = 1; i < sort->operand_count(); ++i) {
    TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
                                           sort->operand(i)->shape()))
        << "All Sort operands must have the same dimensions";
  }

  if (VLOG_IS_ON(3)) {
    for (int64 i = 0; i < sort->operand_count(); ++i) {
      VLOG(3) << "HandleSort operand " << i << " literal: "
              << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
    }
  }
  Shape key_shape = sort->operand(0)->shape();
  auto rank = key_shape.rank();
  std::vector<Literal> result_literals;
  result_literals.reserve(sort->operand_count());
  for (int64 i = 0; i < sort->operand_count(); ++i) {
    result_literals.emplace_back(sort->operand(i)->shape());
  }
  std::vector<int64> zero_base(rank, 0);
  std::vector<int64> increment(rank, 1);
  int64 sort_dim = sort->dimensions(0);
  int64 sort_dim_elements = key_shape.dimensions(sort_dim);
  increment[sort_dim] = sort_dim_elements;
  HloEvaluator embedded_evaluator(max_loop_iterations_);
  // Iterate through each dimension except 'sort_dim'; stepping by
  // 'sort_dim_elements' along 'sort_dim' visits exactly one index per row to
  // be sorted.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
      [&](absl::Span<const int64> indices) -> StatusOr<bool> {
        // Extract a slice from each operand literal that corresponds to
        // exactly the row in dimension 'sort_dim'.
        std::vector<int64> limit_indices(indices.begin(), indices.end());
        absl::c_for_each(limit_indices, [](int64& index) { ++index; });
        limit_indices[sort_dim] = sort_dim_elements;
        std::vector<Literal> literals_to_sort;
        literals_to_sort.reserve(sort->operand_count());
        for (int64 i = 0; i < sort->operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(auto literal_to_sort,
                              GetEvaluatedLiteralFor(sort->operand(i))
                                  .Slice(indices, limit_indices)
                                  .Reshape({sort_dim_elements}));
          literals_to_sort.push_back(std::move(literal_to_sort));
        }
        std::vector<int64> indices_to_sort(sort_dim_elements);
        std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
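        // std::sort/std::stable_sort comparators cannot return a Status, so
        // any error raised while evaluating the comparison computation is
        // stashed in 'compare_status' and surfaced once sorting finishes.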
        Status compare_status = Status::OK();
        auto comparator = [sort, &compare_status, &embedded_evaluator,
                           &literals_to_sort](int64 a, int64 b) {
          std::vector<Literal> literals;
          literals.reserve(2 * sort->operand_count());
          for (int64 i = 0; i < sort->operand_count(); ++i) {
            auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
                                                 /*extract_as_scalar=*/true);
            if (!lhs.ok()) {
              compare_status = lhs.status();
              return false;
            }
            literals.push_back(std::move(lhs.ValueOrDie()));
            auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
                                                 /*extract_as_scalar=*/true);
            if (!rhs.ok()) {
              compare_status = rhs.status();
              return false;
            }
            literals.push_back(std::move(rhs.ValueOrDie()));
          }
          std::vector<const Literal*> literal_ptrs;
          absl::c_transform(literals, std::back_inserter(literal_ptrs),
                            [](const Literal& literal) { return &literal; });

          auto computed_result =
              embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
          // Clear visit states so that we can use the evaluator again on the
          // same computation.
          embedded_evaluator.ResetVisitStates();
          if (!computed_result.ok()) {
            compare_status = computed_result.status();
            return false;
          }
          return computed_result.ValueOrDie().Get<bool>({});
        };
        if (Cast<HloSortInstruction>(sort)->is_stable()) {
          std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
                           comparator);
        } else {
          std::sort(indices_to_sort.begin(), indices_to_sort.end(),
                    comparator);
        }
        if (!compare_status.ok()) {
          return compare_status;
        }
        std::vector<int64> slice_dimensions(rank, 1);
        slice_dimensions[sort_dim] = sort_dim_elements;
        std::vector<int64> start_indices(rank, 0);
        for (int64 i = 0; i < sort->operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(
              Literal sorted_literal,
              ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
          TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
                              sorted_literal.Reshape(slice_dimensions));
          TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
              sorted_literal_reshaped, start_indices, indices,
              slice_dimensions));
        }
        return true;
      }));

  if (sort->operand_count() == 1) {
    evaluated_[sort] = std::move(result_literals[0]);
  } else {
    std::vector<const Literal*> literal_ptrs;
    absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
                      [](const Literal& literal) { return &literal; });

    Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
    VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();

    evaluated_[sort] = std::move(result_tuple);
  }
  return Status::OK();
}

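// A non-tuple reduce takes the default typed-visitor path. Variadic reduce
// (tuple-shaped output) is handled here because the typed visitors are keyed
// on a single element type: when all outputs share one element type the
// reduce is forwarded to the matching visitor, otherwise it is rejected.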
Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
  if (!reduce->shape().IsTuple()) {
    return DefaultAction(reduce);
  } else {
    auto first_element_type = reduce->shape().tuple_shapes(0).element_type();
    for (const auto& tuple_shape : reduce->shape().tuple_shapes()) {
      if (tuple_shape.element_type() != first_element_type) {
        return Unimplemented(
            "Reduce with several outputs that have mixed element types is "
            "unsupported");
      }
    }
    return reduce->Visit(typed_visitors_[first_element_type].get());
  }
}

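// Custom-calls are opaque to the evaluator; they can only be evaluated
// through a handler registered by the client. Without one, defer to
// DefaultAction.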
Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
  if (!custom_call_handler_) {
    // No handler is registered; this means custom-calls are not allowed.
    return DefaultAction(custom_call);
  }

  // Evaluate input operands so the handler has access to the operand data.
  std::vector<const Literal*> operands;
  operands.reserve(custom_call->operand_count());
  for (const HloInstruction* operand : custom_call->operands()) {
    operands.push_back(&GetEvaluatedLiteralFor(operand));
  }

  // Synchronously issue the handler to populate the instruction output
  // literal.
  TF_ASSIGN_OR_RETURN(
      auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));

  evaluated_[custom_call] = std::move(output);
  return Status::OK();
}

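// Called before each instruction is visited; validates the shape so that
// handlers can rely on it being well-formed.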
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
  VLOG(2) << "About to visit HLO: " << hlo->ToString();
  return ShapeUtil::ValidateShape(hlo->shape());
}

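// Called after each instruction has been visited and its literal recorded.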
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
  VLOG(2) << "Finished visiting " << hlo->ToString()
          << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
  // For convenience, the literal may have been produced with a different
  // layout; relayout it as indicated by the HLO instruction.
  if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
                                        hlo->shape())) {
    evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
  }
  return Status::OK();
}

namespace {
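// Shared implementation for the MatmulArray2D overloads below: allocates the
// result array and forwards to one of the single-threaded Eigen matmul
// kernels from the CPU runtime.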
template <typename T>
std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
    const Array2D<T>& lhs, const Array2D<T>& rhs,
    const std::function<void(
        const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n,
        int64 k, int32 transpose_lhs, int32 transpose_rhs)>& impl_fn) {
  CHECK_EQ(lhs.width(), rhs.height());
  int m = lhs.height();
  int n = rhs.width();
  int k = lhs.width();
  auto result = absl::make_unique<Array2D<T>>(m, n);
  // Because Eigen is a header-only library, make sure that the Eigen code
  // here is the same as the code used by the CPU backend (otherwise the
  // linker will randomly pick *some* definition).
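  // Note the swapped operands and dimensions below: Array2D stores its data
  // row-major, while the kernel (presumably following Eigen's default
  // convention) expects column-major buffers; computing rhs^T * lhs^T in
  // column-major yields exactly lhs * rhs in row-major.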
  impl_fn(
      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n,
      m, k,
      /*transpose_lhs=*/0,
      /*transpose_rhs=*/0);
  return result;
}
}  // namespace

std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
    const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
  return MatmulArray2DImpl<Eigen::half>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
}

std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
    const Array2D<float>& lhs, const Array2D<float>& rhs) {
  return MatmulArray2DImpl<float>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
}

std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
    const Array2D<double>& lhs, const Array2D<double>& rhs) {
  return MatmulArray2DImpl<double>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
}

}  // namespace xla