1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16
17 #include <algorithm>
18 #include <cmath>
19 #include <complex>
20 #include <cstdlib>
21 #include <functional>
22 #include <iterator>
23 #include <string>
24 #include <type_traits>
25 #include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/string_view.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/index_util.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/map_util.h"
36 #include "tensorflow/compiler/xla/primitive_util.h"
37 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
38 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
40 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
41 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
42 #include "tensorflow/compiler/xla/service/hlo_query.h"
43 #include "tensorflow/compiler/xla/service/shape_inference.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/statusor.h"
46 #include "tensorflow/compiler/xla/types.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/compiler/xla/window_util.h"
49 #include "tensorflow/core/lib/core/bitmap.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/status.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/stream_executor/lib/statusor.h"
56
57 namespace xla {
58
59 namespace {
60
61 template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
63 LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
64 std::function<bool(OperandT, OperandT)> compare_op;
65 switch (direction) {
66 case ComparisonDirection::kEq:
67 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
68 return lhs_el == rhs_el;
69 };
70 break;
71 case ComparisonDirection::kNe:
72 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
73 return lhs_el != rhs_el;
74 };
75 break;
76 case ComparisonDirection::kGe:
77 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
78 return lhs_el >= rhs_el;
79 };
80 break;
81 case ComparisonDirection::kGt:
82 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
83 return lhs_el > rhs_el;
84 };
85 break;
86 case ComparisonDirection::kLe:
87 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
88 return lhs_el <= rhs_el;
89 };
90 break;
91 case ComparisonDirection::kLt:
92 compare_op = [](OperandT lhs_el, OperandT rhs_el) {
93 return lhs_el < rhs_el;
94 };
95 break;
96 }
97
98 Literal result(shape);
99 TF_RETURN_IF_ERROR(
100 result.Populate<bool>([&](absl::Span<const int64> multi_index) {
101 return compare_op(lhs_literal.Get<OperandT>(multi_index),
102 rhs_literal.Get<OperandT>(multi_index));
103 }));
104
105 return std::move(result);
106 }
107
108 template <>
StatusOr<Literal> Compare<complex64>(const Shape& shape,
110 ComparisonDirection direction,
111 LiteralSlice lhs_literal,
112 LiteralSlice rhs_literal) {
113 std::function<bool(complex64, complex64)> compare_op;
114 switch (direction) {
115 case ComparisonDirection::kEq:
116 compare_op = [](complex64 lhs_el, complex64 rhs_el) {
117 return lhs_el == rhs_el;
118 };
119 break;
120 case ComparisonDirection::kNe:
121 compare_op = [](complex64 lhs_el, complex64 rhs_el) {
122 return lhs_el != rhs_el;
123 };
124 break;
125 default:
126 LOG(FATAL) << "unhandled direction for conversion to Comparison: "
127 << ComparisonDirectionToString(direction);
128 }
129
130 Literal result(shape);
131 TF_RETURN_IF_ERROR(
132 result.Populate<bool>([&](absl::Span<const int64> multi_index) {
133 return compare_op(lhs_literal.Get<complex64>(multi_index),
134 rhs_literal.Get<complex64>(multi_index));
135 }));
136
137 return std::move(result);
138 }
139
140 template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape,
142 ComparisonDirection direction,
143 LiteralSlice lhs_literal,
144 LiteralSlice rhs_literal) {
145 std::function<bool(complex128, complex128)> compare_op;
146 switch (direction) {
147 case ComparisonDirection::kEq:
148 compare_op = [](complex128 lhs_el, complex128 rhs_el) {
149 return lhs_el == rhs_el;
150 };
151 break;
152 case ComparisonDirection::kNe:
153 compare_op = [](complex128 lhs_el, complex128 rhs_el) {
154 return lhs_el != rhs_el;
155 };
156 break;
157 default:
158 LOG(FATAL) << "unhandled direction for conversion to Comparison: "
159 << ComparisonDirectionToString(direction);
160 }
161
162 Literal result(shape);
163 TF_RETURN_IF_ERROR(
164 result.Populate<bool>([&](absl::Span<const int64> multi_index) {
165 return compare_op(lhs_literal.Get<complex128>(multi_index),
166 rhs_literal.Get<complex128>(multi_index));
167 }));
168
169 return std::move(result);
170 }
171
172 } // namespace
173
// Note that a type unsupported by the typed visitor does not necessarily imply
// that the non-typed HloEvaluator (the parent evaluator) cannot support it in
// a type-agnostic handler. For example, HandleGetTupleElement in the parent
// type-agnostic evaluator accepts the TUPLE primitive type, whereas
// HloEvaluatorTypedVisitor cannot.
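//
// A minimal usage sketch (the module, computation, and argument literal below
// are hypothetical placeholders, not symbols defined in this file):
//
//   HloEvaluator evaluator;
//   StatusOr<Literal> result =
//       evaluator.Evaluate(*hlo_module->entry_computation(), {&arg_literal});
//   if (result.ok()) {
//     Literal value = result.ConsumeValueOrDie();
//   }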
HloEvaluator::HloEvaluator(int64_t max_loop_iterations)
180 : max_loop_iterations_(max_loop_iterations) {
181 typed_visitors_[PRED] =
182 absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
183 typed_visitors_[U8] =
184 absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
185 typed_visitors_[U16] =
186 absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
187 typed_visitors_[U32] =
188 absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
189 typed_visitors_[U64] =
190 absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
191 typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
192 typed_visitors_[S16] =
193 absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
194 typed_visitors_[S32] =
195 absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
196 typed_visitors_[S64] =
197 absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
198 typed_visitors_[F16] =
199 absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
200 typed_visitors_[F32] =
201 absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
202 typed_visitors_[F64] =
203 absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
204 typed_visitors_[C64] =
205 absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
206 typed_visitors_[C128] =
207 absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
208
  // Most of the evaluator computations we use don't support BF16 (e.g.,
  // std::ceil, std::tanh). To make the evaluator work with BF16, we perform
  // all elementwise computations in F32 and do BF16<->F32 conversions around
  // the input and the output of the computations.
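  //
  // For example, under this scheme a BF16 Tanh is evaluated roughly as
  //   bfloat16 y = static_cast<bfloat16>(std::tanh(static_cast<float>(x)));
  // (an illustrative sketch of the conversion, not a quote of the visitor
  // code).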
213 typed_visitors_[BF16] =
214 absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
215
216 typed_visitors_[TUPLE] =
217 absl::make_unique<FunctionVisitor>([](HloInstruction*) {
218 return Unimplemented(
219 "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
220 });
221 typed_visitors_[OPAQUE_TYPE] =
222 absl::make_unique<FunctionVisitor>([](HloInstruction*) {
223 return Unimplemented(
224 "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE.");
225 });
226 typed_visitors_[TOKEN] =
227 absl::make_unique<FunctionVisitor>([](HloInstruction*) {
228 return Unimplemented(
229 "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
230 });
231 }
232
StatusOr<Literal> HloEvaluator::Evaluate(
234 const HloComputation& computation,
235 absl::Span<const Literal* const> arg_literals) {
236 CHECK(computation.parent() != nullptr);
237 XLA_VLOG_LINES(
238 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
239
240 if (arg_literals.size() != computation.num_parameters()) {
241 return InvalidArgument(
242 "Expected %d argument%s, but got %d.", computation.num_parameters(),
243 computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
244 }
245 for (int64_t i = 0; i < arg_literals.size(); ++i) {
246 const auto& computation_shape =
247 computation.parameter_instruction(i)->shape();
248 const auto& arg_shape = arg_literals[i]->shape();
249 if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
250 arg_shape)) {
251 return InvalidArgument(
252 "Shape mismatch at parameter %d. Computation expected %s, but arg "
253 "was %s.",
254 i, ShapeUtil::HumanStringWithLayout(computation_shape),
255 ShapeUtil::HumanStringWithLayout(arg_shape));
256 }
257 }
258
259 evaluated_.clear();
260 arg_literals_.clear();
261 for (const auto& literal_ptr : arg_literals) {
262 arg_literals_.push_back(&*literal_ptr);
263 }
264
265 // Re-seed RNG, either from the configuration's seed or a monotonic
266 // per-evaluator seed (which prevents two evaluators from returning the same
267 // random sequence).
268 if (computation.parent()->config().seed()) {
269 seed_ = computation.parent()->config().seed();
270 } else {
271 // Start global_seed at a (true) random value.
272 static std::atomic<uint64> global_seed{std::random_device()()};
273 seed_ = global_seed.fetch_add(1);
274 }
275 engine_.seed(seed_);
276
277 TF_RETURN_IF_ERROR(computation.Accept(this));
278
279 if (VLOG_IS_ON(100)) {
280 for (const HloInstruction* instr : computation.instructions()) {
281 VLOG(100) << instr->name() << " = " << GetEvaluatedLiteralFor(instr);
282 }
283 }
284
285 return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
286 }
287
StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
289 // If the instruction is a kCopyDone, simply find the argument that it is
290 // copied from.
291 while (instruction->opcode() == HloOpcode::kCopyDone) {
292 if (instruction->operand(0)->opcode() != HloOpcode::kCopyStart) {
293 return tensorflow::errors::FailedPrecondition(
294 "kCopyDone has an argument different than a kCopyStart.");
295 }
296 instruction = instruction->mutable_operand(0)->mutable_operand(0);
297 }
298 if (instruction->opcode() == HloOpcode::kParameter) {
299 return tensorflow::errors::FailedPrecondition(
300 "Cannot evaluate a parameter.");
301 }
302 if (!hlo_query::AllOperandsAreConstants(*instruction)) {
303 return tensorflow::errors::FailedPrecondition(
304 "Not all operands are constants.");
305 }
306
307 arg_literals_.clear();
308 evaluated_.clear();
309
310 TF_RETURN_IF_ERROR(Preprocess(instruction));
311 TF_RETURN_IF_ERROR(instruction->Visit(this));
312 TF_RETURN_IF_ERROR(Postprocess(instruction));
313 return GetEvaluatedLiteralFor(instruction).Clone();
314 }
315
bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
317 CHECK(result != nullptr);
318 auto result_or = Evaluate(instruction);
319 if (!result_or.ok()) {
320 VLOG(1) << "TryEvaluate failed:" << result_or.status();
321 return false;
322 }
323
324 *result = result_or.ConsumeValueOrDie();
325 return true;
326 }
327
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
329 const HloInstruction* instruction,
330 const std::unordered_map<const HloInstruction*, const Literal*>&
331 substitutions) {
332 std::vector<std::unique_ptr<HloInstruction>> owned_operands;
333 for (const HloInstruction* operand : instruction->operands()) {
334 auto it = substitutions.find(operand);
335 if (it == substitutions.end()) {
336 owned_operands.push_back(operand->Clone());
337 } else {
338 owned_operands.push_back(
339 HloInstruction::CreateConstant(it->second->Clone()));
340 }
341 }
342
343 std::vector<HloInstruction*> operands;
344 operands.reserve(owned_operands.size());
345 for (auto& operand : owned_operands) {
346 operands.push_back(operand.get());
347 }
348
349 std::unique_ptr<HloInstruction> cloned_instruction =
350 instruction->CloneWithNewOperands(instruction->shape(), operands);
351 auto result = Evaluate(cloned_instruction.get());
352
353 return result;
354 }
355
StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
357 HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
358 std::unique_ptr<HloInstruction> lhs_instr =
359 HloInstruction::CreateConstant(lhs.Clone());
360 std::unique_ptr<HloInstruction> rhs_instr =
361 HloInstruction::CreateConstant(rhs.Clone());
362
363 std::unique_ptr<HloInstruction> cloned_instruction =
364 HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
365 rhs_instr.get());
366 auto result = Evaluate(cloned_instruction.get());
367
368 return result;
369 }
370
StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
372 HloOpcode opcode, const Literal& lhs, const Literal& rhs,
373 const Literal& ehs) {
374 std::unique_ptr<HloInstruction> lhs_instr =
375 HloInstruction::CreateConstant(lhs.Clone());
376 std::unique_ptr<HloInstruction> rhs_instr =
377 HloInstruction::CreateConstant(rhs.Clone());
378 std::unique_ptr<HloInstruction> ehs_instr =
379 HloInstruction::CreateConstant(ehs.Clone());
380 TF_ASSIGN_OR_RETURN(auto output_shape,
381 ShapeInference::InferTernaryOpShape(
382 opcode, lhs.shape(), rhs.shape(), ehs.shape()));
383 std::unique_ptr<HloInstruction> cloned_instruction =
384 HloInstruction::CreateTernary(output_shape, opcode, lhs_instr.get(),
385 rhs_instr.get(), ehs_instr.get());
386 return Evaluate(cloned_instruction.get());
387 }
388
StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
390 ComparisonDirection direction, const Literal& lhs, const Literal& rhs) {
391 std::unique_ptr<HloInstruction> lhs_instr =
392 HloInstruction::CreateConstant(lhs.Clone());
393 std::unique_ptr<HloInstruction> rhs_instr =
394 HloInstruction::CreateConstant(rhs.Clone());
395
396 std::unique_ptr<HloInstruction> cloned_instruction =
397 HloInstruction::CreateCompare(
398 ShapeUtil::ChangeElementType(lhs.shape(), PRED), lhs_instr.get(),
399 rhs_instr.get(), direction);
400 auto result = Evaluate(cloned_instruction.get());
401
402 return result;
403 }
404
StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
406 HloOpcode opcode, const Literal& operand) {
407 std::unique_ptr<HloInstruction> operand_instr =
408 HloInstruction::CreateConstant(operand.Clone());
409
410 std::unique_ptr<HloInstruction> cloned_instruction =
411 HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
412 auto result = Evaluate(cloned_instruction.get());
413
414 return result;
415 }
416
StatusOr<Literal> HloEvaluator::EvaluateDotOp(
418 const DotDimensionNumbers& dim_numbers,
419 const PrecisionConfig& precision_config, const Literal& lhs,
420 const Literal& rhs) {
421 std::unique_ptr<HloInstruction> lhs_instr =
422 HloInstruction::CreateConstant(lhs.Clone());
423 std::unique_ptr<HloInstruction> rhs_instr =
424 HloInstruction::CreateConstant(rhs.Clone());
425
426 TF_ASSIGN_OR_RETURN(Shape dot_shape,
427 ShapeInference::InferDotOpShape(
428 lhs.shape(), rhs.shape(), dim_numbers,
429 /*preferred_element_type=*/absl::nullopt));
430
431 std::unique_ptr<HloInstruction> cloned_instruction =
432 HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
433 dim_numbers, precision_config);
434 return Evaluate(cloned_instruction.get());
435 }
436
Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
438 const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
439 Literal result(bitcast->shape());
440 TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
441 memcpy(result.untyped_data(), operand_literal.untyped_data(),
442 operand_literal.size_bytes());
443 evaluated_[bitcast] = std::move(result);
444 return Status::OK();
445 }
446
Status HloEvaluator::HandleGetDimensionSize(
448 HloInstruction* get_dimension_size) {
449 HloInstruction* operand = get_dimension_size->mutable_operand(0);
450 int64_t dim = get_dimension_size->dimension();
451 if (dynamic_dimension_inference_ == nullptr) {
452 return InvalidArgument(
453 "Evaluator cannot evaluate get_dimension_size without "
454 "set_dynamic_dimension_inference.");
455 }
456 HloInstruction* dynamic_size =
457 dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
458 if (dynamic_size != nullptr) {
459 evaluated_[get_dimension_size] =
460 GetEvaluatedLiteralFor(dynamic_size).Clone();
461 return Status::OK();
462 }
463
464 const Shape& shape = get_dimension_size->operand(0)->shape();
465 Literal output(ShapeUtil::MakeShape(S32, {}));
466 output.PopulateWithValue(
467 static_cast<int32>(shape.dimensions(get_dimension_size->dimension())));
468 evaluated_[get_dimension_size] = std::move(output);
469 return Status::OK();
470 }
471
Status HloEvaluator::HandleSetDimensionSize(
473 HloInstruction* set_dimension_size) {
474 const Literal& operand_literal =
475 GetEvaluatedLiteralFor(set_dimension_size->operand(0));
476 Literal result(set_dimension_size->shape());
477 memcpy(result.untyped_data(), operand_literal.untyped_data(),
478 operand_literal.size_bytes());
479 const Literal& size_literal =
480 GetEvaluatedLiteralFor(set_dimension_size->operand(1));
481 result.SetDynamicSize(set_dimension_size->dimension(),
482 size_literal.Get<int32>({}));
483 evaluated_[set_dimension_size] = std::move(result);
484 return Status::OK();
485 }
486
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
488 // Nothing to do other than sanity checks. Parameters' values are stored in
489 // arg_literals_.
490 CHECK_LT(parameter->parameter_number(), arg_literals_.size());
491
492 #ifndef NDEBUG
493 const Literal* input_literal = arg_literals_[parameter->parameter_number()];
494 VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
495 DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
496 input_literal->shape()))
497 << "parameter shape is: "
498 << ShapeUtil::HumanStringWithLayout(parameter->shape())
499 << ", but input literal shape is: "
500 << ShapeUtil::HumanStringWithLayout(input_literal->shape());
501 #endif
502
503 return Status::OK();
504 }
505
Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
507
Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
509 TF_ASSIGN_OR_RETURN(
510 evaluated_[reshape],
511 GetEvaluatedLiteralFor(reshape->operand(0))
512 .Reshape(AsInt64Slice(reshape->shape().dimensions())));
513 return Status::OK();
514 }
515
Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
517 evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
518 .Transpose(transpose->dimensions());
519 return Status::OK();
520 }
521
Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
523 absl::Span<HloInstruction* const> operands(concatenate->operands());
  // The size of the result's concatenate dimension is the sum of the sizes of
  // that dimension across all operands taking part in the operation.
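  // For example (illustrative shapes), concatenating f32[2,3] and f32[2,5]
  // along dimension 1 produces f32[2,8].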
526 const Shape& reference_shape = operands[0]->shape();
527 CHECK(reference_shape.IsArray());
528 const int64_t rank = reference_shape.rank();
529 const int64_t concat_dim = concatenate->dimensions()[0];
530 CHECK_GE(concat_dim, 0);
531 CHECK_LT(concat_dim, rank);
532
533 DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
534 reference_shape.dimensions().end());
535
536 for (int64_t i = 1; i < operands.size(); ++i) {
537 const Shape& operand_shape = operands[i]->shape();
538 CHECK(operand_shape.IsArray());
539 // Accumulate the concat dimension from all tensors taking part to the
540 // operation.
541 concat_dimensions[concat_dim] +=
542 ShapeUtil::GetDimension(operand_shape, concat_dim);
543 }
544
545 auto result_literal = LiteralUtil::CreateFromDimensions(
546 reference_shape.element_type(), concat_dimensions);
547 DimensionVector source_indices(rank, 0);
548 DimensionVector dest_indices(concat_dimensions.size(), 0);
549
550 for (auto operand : operands) {
551 const Shape& operand_shape = operand->shape();
552 TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
553 GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
554 AsInt64Slice(operand_shape.dimensions())));
555 dest_indices[concat_dim] +=
556 ShapeUtil::GetDimension(operand_shape, concat_dim);
557 }
558
559 evaluated_[concatenate] = std::move(result_literal);
560 return Status::OK();
561 }
562
Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
564 auto operand = is_finite->operand(0);
565 auto elem_ty = operand->shape().element_type();
566 switch (elem_ty) {
567 case PRED:
568 case TUPLE:
569 case OPAQUE_TYPE:
570 case TOKEN:
571 case S8:
572 case S16:
573 case S32:
574 case S64:
575 case U8:
576 case U16:
577 case U32:
578 case U64:
579 case C64:
580 case C128:
581 // Explicitly enumerate all types in this switch so that when we add a new
582 // type, we'll get a compile error here.
583 case PRIMITIVE_TYPE_INVALID:
584 case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
585 case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
586 return InvalidArgument(
587 "expected element type in shape to be floating point, but "
588 "got: %s",
589 PrimitiveType_Name(elem_ty));
590
591 case F16: {
592 auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
593 is_finite,
594 [](Eigen::half elem_operand) {
595 return std::isfinite(static_cast<float>(elem_operand));
596 },
597 GetEvaluatedLiteralFor(operand));
598 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
599 break;
600 }
601 case BF16: {
602 auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
603 is_finite,
604 [](bfloat16 elem_operand) {
605 return std::isfinite(static_cast<float>(elem_operand));
606 },
607 GetEvaluatedLiteralFor(operand));
608 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
609 break;
610 }
611 case F32: {
612 auto result_or = ElementWiseUnaryOpImpl<bool, float>(
613 is_finite,
614 [](float elem_operand) { return std::isfinite(elem_operand); },
615 GetEvaluatedLiteralFor(operand));
616 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
617 break;
618 }
619 case F64: {
620 auto result_or = ElementWiseUnaryOpImpl<bool, double>(
621 is_finite,
622 [](double elem_operand) { return std::isfinite(elem_operand); },
623 GetEvaluatedLiteralFor(operand));
624 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
625 break;
626 }
627 }
628
629 return Status::OK();
630 }
631
Status HloEvaluator::HandleReal(HloInstruction* real) {
633 auto operand = real->operand(0);
634 switch (operand->shape().element_type()) {
635 case BF16: {
636 auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
637 real, [](bfloat16 elem_operand) { return elem_operand; },
638 GetEvaluatedLiteralFor(operand));
639 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
640 break;
641 }
642 case C64: {
643 auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
644 real, [](complex64 elem_operand) { return std::real(elem_operand); },
645 GetEvaluatedLiteralFor(operand));
646 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
647 break;
648 }
649 case C128: {
650 auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
651 real, [](complex128 elem_operand) { return std::real(elem_operand); },
652 GetEvaluatedLiteralFor(operand));
653 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
654 break;
655 }
656 case F16: {
657 auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
658 real, [](Eigen::half elem_operand) { return elem_operand; },
659 GetEvaluatedLiteralFor(operand));
660 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
661 break;
662 }
663 case F32: {
664 auto result_or = ElementWiseUnaryOpImpl<float, float>(
665 real, [](float elem_operand) { return elem_operand; },
666 GetEvaluatedLiteralFor(operand));
667 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
668 break;
669 }
670 case F64: {
671 auto result_or = ElementWiseUnaryOpImpl<double, double>(
672 real, [](double elem_operand) { return elem_operand; },
673 GetEvaluatedLiteralFor(operand));
674 TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
675 break;
676 }
677 default:
678 LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
679 << PrimitiveType_Name(operand->shape().element_type());
680 }
681
682 return Status::OK();
683 }
684
Status HloEvaluator::HandleImag(HloInstruction* imag) {
686 auto operand = imag->operand(0);
687 switch (operand->shape().element_type()) {
688 case C64: {
689 auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
690 imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
691 GetEvaluatedLiteralFor(imag->operand(0)));
692
693 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
694 break;
695 }
696 case C128: {
697 auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
698 imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
699 GetEvaluatedLiteralFor(imag->operand(0)));
700
701 TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
702 break;
703 }
704 default:
705 LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
706 << PrimitiveType_Name(operand->shape().element_type());
707 }
708
709 return Status::OK();
710 }
711
Status HloEvaluator::HandleComplex(HloInstruction* complex) {
713 const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
714 const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
715 TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
716
717 Literal result(complex->shape());
718 switch (complex->shape().element_type()) {
719 case C64: {
720 TF_RETURN_IF_ERROR(
721 result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
722 return std::complex<float>(real.Get<float>(multi_index),
723 imag.Get<float>(multi_index));
724 }));
725 break;
726 }
727 case C128: {
728 TF_RETURN_IF_ERROR(
729 result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
730 return std::complex<double>(real.Get<double>(multi_index),
731 imag.Get<double>(multi_index));
732 }));
733 break;
734 }
735 default:
736 LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
737 << PrimitiveType_Name(complex->shape().element_type());
738 }
739
740 evaluated_[complex] = std::move(result);
741 return Status::OK();
742 }
743
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
745 ComparisonDirection direction = compare->comparison_direction();
746 auto lhs = compare->operand(0);
747 auto rhs = compare->operand(1);
748 DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
749 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
750
751 TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
752
753 const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
754 const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
755
756 // Note here we switch on the operand's type.
757 switch (lhs->shape().element_type()) {
758 case PRED: {
759 TF_ASSIGN_OR_RETURN(
760 evaluated_[compare],
761 Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
762 } break;
763 case U8: {
764 TF_ASSIGN_OR_RETURN(evaluated_[compare],
765 Compare<uint8>(compare->shape(), direction,
766 lhs_literal, rhs_literal));
767 } break;
768 case U16: {
769 TF_ASSIGN_OR_RETURN(evaluated_[compare],
770 Compare<uint16>(compare->shape(), direction,
771 lhs_literal, rhs_literal));
772 } break;
773 case U32: {
774 TF_ASSIGN_OR_RETURN(evaluated_[compare],
775 Compare<uint32>(compare->shape(), direction,
776 lhs_literal, rhs_literal));
777 } break;
778 case U64: {
779 TF_ASSIGN_OR_RETURN(evaluated_[compare],
780 Compare<uint64>(compare->shape(), direction,
781 lhs_literal, rhs_literal));
782 } break;
783 case S8: {
784 TF_ASSIGN_OR_RETURN(
785 evaluated_[compare],
786 Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
787 } break;
788 case S16: {
789 TF_ASSIGN_OR_RETURN(evaluated_[compare],
790 Compare<int16>(compare->shape(), direction,
791 lhs_literal, rhs_literal));
792 } break;
793 case S32: {
794 TF_ASSIGN_OR_RETURN(evaluated_[compare],
795 Compare<int32>(compare->shape(), direction,
796 lhs_literal, rhs_literal));
797 } break;
798 case S64: {
799 TF_ASSIGN_OR_RETURN(evaluated_[compare],
800 Compare<int64>(compare->shape(), direction,
801 lhs_literal, rhs_literal));
802 } break;
803 case F16: {
804 TF_ASSIGN_OR_RETURN(
805 evaluated_[compare],
806 Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
807 } break;
808 case BF16: {
809 TF_ASSIGN_OR_RETURN(evaluated_[compare],
810 Compare<bfloat16>(compare->shape(), direction,
811 lhs_literal, rhs_literal));
812 } break;
813 case F32: {
814 TF_ASSIGN_OR_RETURN(evaluated_[compare],
815 Compare<float>(compare->shape(), direction,
816 lhs_literal, rhs_literal));
817 } break;
818 case F64: {
819 TF_ASSIGN_OR_RETURN(evaluated_[compare],
820 Compare<double>(compare->shape(), direction,
821 lhs_literal, rhs_literal));
822 } break;
823 case C64: {
824 TF_ASSIGN_OR_RETURN(evaluated_[compare],
825 Compare<complex64>(compare->shape(), direction,
826 lhs_literal, rhs_literal));
827 } break;
828 case C128: {
829 TF_ASSIGN_OR_RETURN(evaluated_[compare],
830 Compare<complex128>(compare->shape(), direction,
831 lhs_literal, rhs_literal));
832 } break;
833 default:
834 LOG(FATAL) << "HandleCompare: unknown primitive type: "
835 << PrimitiveType_Name(lhs->shape().element_type());
836 }
837
838 return Status::OK();
839 }
840
Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
842 std::vector<const Literal*> operand_literals;
843 for (auto operand : tuple->operands()) {
844 operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
845 }
846
847 evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
848 return Status::OK();
849 }
850
851 namespace {
852
853 // These helper templates convert the data type and are intended to be used only
// within the DFT implementation below. The special case is the float
// specialization (used for IRFFT output), which drops the imaginary part of a
// complex value and returns only its real part.
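//
// For instance, TypeConverter<float, complex64>::GetAs(complex64(3.0f, 4.0f))
// yields 3.0f (the real part is kept and the imaginary part is dropped).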
857 template <typename ToType, typename FromType>
858 struct TypeConverter {
  static inline ToType GetAs(FromType value) {
860 return static_cast<ToType>(value);
861 }
862 };
863
864 template <typename FromType>
865 struct TypeConverter<float, FromType> {
  static inline float GetAs(FromType value) {
867 return static_cast<float>(value.real());
868 }
869 };
870
871 // This class implements the discrete Fourier transform. All transform types
872 // (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the arbitrary rank and
873 // length of each dimension of the transform, and arbitrary layouts of the input
874 // and output literals. The class template parameter must be a complex type, and
875 // all internal calculations will be performed using this type.
876 //
877 // The input literal provides input data, which must be complex64 for FFT, IFFT,
878 // IRFFT transforms and float for RFFT. The transform is computed over the
// innermost dimensions of the input, so the rank of the input data must be the
// same as fft_rank or larger. The input is expected to provide Ni values along
881 // each transform axis with one exception: for IRFFT, only (N0 / 2) + 1 values
882 // are needed along the X axis (the innermost index). To increase flexibility,
883 // this implementation can handle mismatches between the input size and
884 // transform lengths by either dropping extra input values or using zeroes in
885 // place of missing input values as necessary. If the input data has rank higher
886 // than the transform, the transform is applied for each valid combination of
887 // the higher-ranking indices.
888 //
889 // The output contains complex64 values for FFT, IFFT, RFFT, and float values
890 // for IRFFT. The rank of the output as well as the sizes of the dimensions
891 // above the rank of the transform must match those of the input. Sizes of the
892 // output's "fft_rank" innermost dimensions are expected to match the length of
893 // the transform along respective axes with one exception: for RFFT, the output
894 // is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the
895 // length(s) mismatch, the FFT output is trimmed to fit into the provided output
896 // shape, or the output is padded with zero values appropriately.
897 //
898 // For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17]
899 // input array will perform two transforms over the [][15][17] data in the sub
900 // arrays [0][][] and [1][][], dropping the values along axis X and padding axis
901 // Y with zeroes to create 16x16 working sets, and generating
902 // complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to
903 // complex64[64][16][9] input array will use all input values and will produce
904 // float[64][16][16] output.
905 //
// The implementation of the 1D transform for lengths that are powers of 2 is
// the Cooley-Tukey radix-2 decimation-in-time. For all other 1D transform
908 // lengths, a straightforward, but slow, loop nest is used. The transforms of
909 // higher ranks apply sets of 1D transforms along each axis. For example, the 2D
910 // transform is computed by applying 1D transforms to each column followed by
911 // applying 1D transforms to each row.
912 //
// In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
// time, where Ni is the length of the transform's i-th dimension. However, for
// dimension lengths that are powers of 2, the run time along those dimensions
// is reduced to log(Ni) in the summation, giving a best-case runtime of
// O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn))).
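//
// As a reference for the code below, the forward 1D transform computed along
// each axis is
//   X[k] = sum_{n = 0}^{N - 1} x[n] * exp(-2 * pi * i * n * k / N),
// and the inverse transform uses the conjugated exponent and divides each
// output value by N (see Twiddle(), NaiveDft1D(), and Fft1D()).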
918 //
919 template <typename ComplexType>
920 class FftTransform {
921 public:
  explicit FftTransform(HloInstruction* fft)
923 : fft_type_(fft->fft_type()),
924 fft_rank_(fft->fft_length().size()),
925 fft_lengths_(fft->fft_length()) {
926 // Make fft_lengths_[0] the minormost dimension.
927 absl::c_reverse(fft_lengths_);
928 }
929
  Status ComputeFft(HloInstruction* fft, const Literal& input_literal,
931 Literal* output_literal) {
932 const Shape& input_shape = input_literal.shape();
933 const Shape& output_shape = fft->shape();
934
935 TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape));
936
937 const auto fft_strides = ComputeStrides(fft_lengths_);
938
939 // Working set size.
940 const int64_t fft_size = fft_strides[fft_rank_];
941
942 if (fft_size > 0) {
943 // Linearized working data set.
944 std::vector<ComplexType> data(fft_size);
945
946 // Temporary buffer allocated once and used in 1D sweeps. For dimension
947 // length values that are powers of 2, the buffer should be twice as large
      // to simultaneously hold input and output in Fft1D() below.
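      // For example (illustrative lengths), fft_lengths_ = {8, 5} requires a
      // buffer of max(8 * 2, 5) = 16 elements.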
949 int64_t buffer_size = 0;
950 for (auto len : fft_lengths_) {
951 int64_t size = IsPowerOfTwo(static_cast<uint64>(len)) ? len * 2 : len;
952 buffer_size = std::max(buffer_size, size);
953 }
954 std::vector<ComplexType> buffer(buffer_size);
955
956 // Sizes of each axis of input and output literals.
957 const auto input_lengths = GetDimensionLengths(input_literal);
958 const auto output_lengths = GetDimensionLengths(*output_literal);
959
960 // Strides for generating linearized indices into multidimensional arrays.
961 const auto input_strides = ComputeStrides(input_lengths, input_literal);
962 const auto output_strides =
963 ComputeStrides(output_lengths, *output_literal);
964
965 // Visit all elements in the dimensions with ranks above the FFT rank. For
966 // each such element invoke the transform. Use separate indices for the
967 // input and the output to allow different layouts.
968 auto base_case = [&](int64_t axis, int64_t output_index,
969 int64_t input_index, bool within_src_bounds) {
970 if (axis == fft_rank_ - 1) {
971 // Base case: copy the data from the input literal, apply the
972 // transform, and copy the result to the output literal.
973 CHECK(within_src_bounds);
974 bool input_is_zero = CopyDataFromInput(
975 input_literal, input_index, fft_size, fft_lengths_, fft_strides,
976 input_lengths, input_strides, absl::MakeSpan(data));
977 if (!input_is_zero) {
978 // Make 1D sweeps along each transform axis.
979 Sweep(fft_lengths_, fft_strides, absl::MakeSpan(data),
980 absl::MakeSpan(buffer));
981 }
982 CopyDataToOutput(absl::MakeSpan(data), output_index, fft_lengths_,
983 fft_strides, output_lengths, output_strides,
984 output_literal);
985 return true;
986 }
987 return false;
988 };
989 GenerateIndices(output_lengths, output_strides, input_lengths,
990 input_strides, input_shape.rank(), 0, 0, base_case);
991 }
992
993 return Status::OK();
994 }
995
996 private:
997 // Common code used by 1D implementations, which copies data from the input to
998 // the contiguous buffer. Returns true if all copied values are zero.
  static bool GatherToBuffer(absl::Span<ComplexType> data, int64_t length,
1000 int64_t start, int64_t stride, bool expand_input,
1001 absl::Span<ComplexType> buffer) {
1002 CHECK_GE(buffer.size(), length);
1003 bool input_is_zero = true;
1004 const int64_t ub = expand_input ? length / 2 + 1 : length;
1005 CHECK_GE(data.size(), start + (ub - 1) * stride);
1006 for (int64_t k = 0; k < ub; k++) {
1007 ComplexType value = data[start + k * stride];
1008 input_is_zero &= value == ComplexType(0.0, 0.0);
1009 buffer[k] = value;
1010 if (expand_input) {
1011 // Use conjugates of the values at indices [1 ... (ub - 2)] when the
1012 // length is even and at indices [1 ... (ub - 1)] when the length is odd
1013 // to calculate missing values at indices [(length - 1) ... ub].
1014 if (k > 0 && k < (length - ub + 1)) {
1015 buffer[length - k] = std::conj(value);
1016 }
1017 }
1018 }
1019 return input_is_zero;
1020 }
1021
  // Returns the k-th twiddle factor for the given length (conjugated if
  // 'inverse' is true).
  static inline ComplexType Twiddle(int64_t k, int64_t length, bool inverse) {
1025 auto coeff = std::exp(ComplexType(0.0, -2.0 * M_PI * k / length));
1026 return inverse ? std::conj(coeff) : coeff;
1027 }
1028
1029 // Straightforward implementation of 1D DFT transform of arbitrary length.
1030 // Uses passed-in start index and stride to gather inputs from the data vector
1031 // into the preallocated buffer, computes the result, and writes it back to
1032 // the same locations in the data vector. Runs in O(length^2) time.
1033 //
1034 // Parameters contract_output and expand_input are used to avoid unnecessary
1035 // calculations. When contract_output is set to true, then only (length / 2) +
1036 // 1 output values are computed. When expand_input is set to true, then
1037 // (length / 2) + 1 values from the data set are used to re-create the full
1038 // set of size 'length', on which the transform is then performed.
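  //
  // For example, with length = 8, contract_output computes only the outputs
  // k = 0..4 (length / 2 + 1 = 5 values), and expand_input reconstructs all 8
  // input values from 5 stored values using conjugate symmetry.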
1039 //
  static void NaiveDft1D(int64_t length, int64_t start, int64_t stride,
1041 bool inverse, bool contract_output, bool expand_input,
1042 absl::Span<ComplexType> data,
1043 absl::Span<ComplexType> buffer) {
1044 const bool input_is_zero =
1045 GatherToBuffer(data, length, start, stride, expand_input, buffer);
1046
1047 if (!input_is_zero) {
1048 const int64_t ub = contract_output ? length / 2 + 1 : length;
1049 for (int64_t k = 0; k < ub; k++) {
1050 ComplexType value = ComplexType(0.0, 0.0);
1051 for (int n = 0; n < length; n++) {
1052 value += buffer[n] * Twiddle(n * k, length, inverse);
1053 }
1054 data[start + k * stride] =
1055 inverse ? value / ComplexType(length, 0.0) : value;
1056 }
1057 }
1058 }
1059
1060 // Non-recursive implementation of the Cooley-Tukey radix-2 decimation in
  // time. Performs the 1D FFT transform for lengths that are powers of 2.
1062 // Runs in O(length * log(length)) time. Uses the same parameters as the naive
1063 // implementation above, except that the preallocated buffer must be at least
1064 // twice as big as the length of the transform, because the buffer is used to
1065 // hold both input and output values for each stage of the transform.
1066 //
  static void Fft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
1068 bool contract_output, bool expand_input,
1069 absl::Span<ComplexType> data,
1070 absl::Span<ComplexType> buffer) {
1071 CHECK(IsPowerOfTwo(static_cast<uint64>(length)));
1072 const bool input_is_zero =
1073 GatherToBuffer(data, length, start, stride, expand_input, buffer);
1074
1075 if (!input_is_zero) {
1076 auto generate_twiddles = [](int64_t length, bool inverse) {
1077 std::vector<ComplexType> twiddles;
1078 // Need only half the twiddles.
1079 for (int64_t k = 0; k < length / 2; k++) {
1080 twiddles.push_back(Twiddle(k, length, inverse));
1081 }
1082 return twiddles;
1083 };
1084
1085 // Indices into the parts of the buffer used for input and output values.
1086 int64_t in_base = length;
1087 int64_t out_base = 0;
1088
1089 // At each stage, we "split" the input data into num_blocks, with
1090 // block_size values in each block.
1091 for (int64_t num_blocks = 1; num_blocks < length; num_blocks *= 2) {
1092 // Swap input and output parts of the buffer.
1093 std::swap(in_base, out_base);
1094 auto twiddles = generate_twiddles(num_blocks * 2, inverse);
1095 const int64_t block_size = length / num_blocks;
1096 const int64_t next_iteration_block_size = block_size / 2;
1097 for (int64_t block = 0; block < num_blocks; block++) {
1098 const int64_t in_offset = in_base + block * block_size;
1099 const int64_t out_offset =
1100 out_base + block * next_iteration_block_size;
1101 // For each (even, odd) pair of values in the block, calculate two
1102 // output values as even + twiddle * odd and even - twiddle * odd.
1103 for (int64_t pair = 0; pair < block_size / 2; pair++) {
1104 const ComplexType even = buffer[in_offset + pair];
1105 const ComplexType odd = buffer[in_offset + block_size / 2 + pair];
1106 const ComplexType twiddled_odd = twiddles[block] * odd;
1107 buffer[out_offset + pair] = even + twiddled_odd;
1108 buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
1109 }
1110 }
1111 }
1112 // Copy computed result back to data.
1113 const int64_t ub = contract_output ? length / 2 + 1 : length;
1114 for (int64_t k = 0; k < ub; k++) {
1115 ComplexType value = buffer[out_base + k];
1116 data[start + k * stride] =
1117 inverse ? value / ComplexType(length, 0.0) : value;
1118 }
1119 }
1120 }
1121
  // Determine which implementation of the 1D transform to use and call it.
  static void Dft1D(int64_t length, int64_t start, int64_t stride, bool inverse,
1124 bool contract_output, bool expand_input,
1125 absl::Span<ComplexType> data,
1126 absl::Span<ComplexType> buffer) {
1127 if (IsPowerOfTwo(static_cast<uint64>(length))) {
1128 Fft1D(length, start, stride, inverse, contract_output, expand_input, data,
1129 buffer);
1130 } else {
1131 NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
1132 data, buffer);
1133 }
1134 }
1135
1136 // Helper to reverse the order of dimension lengths in the passed-in literal.
  static std::vector<int64> GetDimensionLengths(const Literal& literal) {
1138 auto dimensions = literal.shape().dimensions();
1139 return std::vector<int64>(dimensions.rbegin(), dimensions.rend());
1140 }
1141
1142 // Helper to compute strides for creating linear indices into multidimensional
1143 // data from the dimension lengths and the layout. Returns a new vector of
1144 // size lengths.size() + 1. The last element of the returned vector at index
1145 // [lengths.size()] contains the product of all dimension lengths.
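  //
  // For example (assuming the default layout and illustrative lengths),
  // lengths = {8, 4, 2} yields strides = {1, 8, 32, 64}.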
  static std::vector<int64> ComputeStrides(
1147 const absl::Span<const int64> lengths, const Layout& layout) {
1148 const int64_t num_dimensions = lengths.size();
1149
1150 // Make sure that the layout length matches the number of dimensions.
1151 CHECK_EQ(num_dimensions, layout.minor_to_major_size());
1152
1153 // Calculate strides using layout-specified ordering of the dimensions and
1154 // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
1155 std::vector<int64> strides(num_dimensions + 1);
1156 int64_t stride = 1;
1157 for (int64_t i = 0; i < num_dimensions; i++) {
1158 // Reverse the ordering of the dimensions in the layout.
1159 const int64_t index = (num_dimensions - 1) - layout.minor_to_major(i);
1160 strides[index] = stride;
1161 stride *= lengths[index];
1162 }
1163 strides[num_dimensions] = stride;
1164
1165 return strides;
1166 }
1167
1168 // Compute strides as above using the default layout.
  static std::vector<int64> ComputeStrides(
1170 const absl::Span<const int64> lengths) {
1171 return ComputeStrides(lengths,
1172 LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
1173 }
1174
1175 // Compute strides as above using the layout from the literal, if available.
  static std::vector<int64> ComputeStrides(
1177 const absl::Span<const int64> lengths, const Literal& literal) {
1178 return literal.shape().has_layout()
1179 ? ComputeStrides(lengths, literal.shape().layout())
1180 : ComputeStrides(lengths);
1181 }
1182
1183 // Make 1D sweeps along each transform axis.
  void Sweep(const absl::Span<const int64> fft_lengths,
1185 const absl::Span<const int64> fft_strides,
1186 absl::Span<ComplexType> data, absl::Span<ComplexType> buffer) {
1187 const bool inverse =
1188 fft_type_ == FftType::IFFT || fft_type_ == FftType::IRFFT;
1189 const bool input_is_truncated = fft_type_ == FftType::IRFFT;
1190 const bool output_is_truncated = fft_type_ == FftType::RFFT;
1191
1192 // Recursively visit each column of the data along the sweep_axis. Calculate
1193 // linearized index of that column's first element and the stride, then
1194 // invoke 1D transform. For RFFT, avoid calculating unused output values:
1195 // first, compute only (length_x / 2) + 1 values along the X axis, then
1196 // limit the X coordinate to [0 ... (length / 2)] during the sweeps along
1197 // other axes. Similarly, for IRFFT sweep along higher dimensions first,
1198 // while keeping the X coordinate in the [0 ... (length / 2)] range, then
1199 // re-create negative frequencies omitted in the input and perform the
1200 // full-length transform along the X axis in the last sweep.
1201 std::function<void(int64_t, int64_t, int64_t)> sweep =
1202 [&](int64_t sweep_axis, int64_t axis, int64_t start) {
1203 if (axis < 0) {
1204 // Base case: invoke 1D transform.
1205 const int64_t length = fft_lengths[sweep_axis];
1206 const int64_t stride = fft_strides[sweep_axis];
1207 const bool expand_input = input_is_truncated && sweep_axis == 0;
          const bool contract_output =
              output_is_truncated && sweep_axis == 0;
          Dft1D(length, start, stride, inverse, contract_output,
                expand_input, data, buffer);
1212 } else if (axis == sweep_axis) {
1213 // Visit only the elements with coordinate 0 along the sweep axis.
1214 sweep(sweep_axis, axis - 1, start);
1215 } else {
1216 const int64_t length = fft_lengths[axis];
1217 const bool is_truncated = input_is_truncated || output_is_truncated;
1218 const int64_t ub =
1219 is_truncated && axis == 0 ? (length / 2) + 1 : length;
1220 for (int64_t i = 0; i < ub; i++) {
1221 sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
1222 }
1223 }
1224 };
1225 if (input_is_truncated) {
1226 // Sweep along the X axis last for IRFFT.
1227 for (int64_t sweep_axis = fft_rank_ - 1; sweep_axis >= 0; sweep_axis--) {
1228 sweep(sweep_axis, fft_rank_ - 1, 0);
1229 }
1230 } else {
1231 // Sweep along the X axis first for RFFT. The order does not matter for
1232 // FFT and IFFT types; handle them here as well.
1233 for (int64_t sweep_axis = 0; sweep_axis < fft_rank_; sweep_axis++) {
1234 sweep(sweep_axis, fft_rank_ - 1, 0);
1235 }
1236 }
1237 }
1238
1239 // This template generates two linearized indices, which can be used to access
1240 // multidimensional arrays. It uses a recursive function, which passes the
1241 // indices to the user-supplied callback function. The destination index is
1242 // always within dst_lengths[] bounds. The boolean parameter within_src_bounds
1243 // indicates whether the source index is within src_lengths[] bounds.
1244 //
1245 // The value returned from the callback function controls the recursion depth.
1246 // Returning true indicates that the base case had been hit and the recursion
1247 // stops. Otherwise, the recursion proceeds along the next less-major axis.
1248 //
1249 // For example, the base case when the axis value becomes negative invokes the
1250 // callback function for each possible index within dst_lengths[] bounds. The
1251 // base case when the axis value is equal to zero limits the indices to point
1252 // only to first elements along the minor-most dimension, allowing the
1253 // callback function to handle all values along the X axis.
1254 //
1255 template <typename BaseFn>
  static void GenerateIndices(const absl::Span<const int64> dst_lengths,
1257 const absl::Span<const int64> dst_strides,
1258 const absl::Span<const int64> src_lengths,
1259 const absl::Span<const int64> src_strides,
1260 int64_t rank, int64_t dst_start,
1261 int64_t src_start, BaseFn&& base) {
1262 CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
1263 CHECK_GE(dst_lengths.size(), rank);
1264 CHECK_EQ(src_lengths.size() + 1, src_strides.size());
1265 CHECK_GE(src_lengths.size(), rank);
1266
1267 std::function<void(int64_t, int64_t, int64_t, bool)> generate =
1268 [&](int64_t axis, int64_t dst_index, int64_t src_index,
1269 bool within_src_bounds) {
1270 if (!base(axis, dst_index, src_index, within_src_bounds)) {
1271 for (int64_t i = 0; i < dst_lengths[axis]; i++) {
1272 // Because the loop goes over dst_lengths[], the source index may
1273 // be out of src_lengths[] bounds. In this case, within_src_bounds
1274 // is false.
1275 within_src_bounds &= i < src_lengths[axis];
1276 generate(axis - 1, dst_index, src_index, within_src_bounds);
1277 dst_index += dst_strides[axis];
1278 src_index += src_strides[axis];
1279 }
1280 }
1281 };
1282 generate(rank - 1, dst_start, src_start, true);
1283 }
1284
1285 // Copies the input data from a literal to a pre-allocated vector. The sizes
1286 // of the input and the transform do not need to match. For each axis of the
1287 // transform, any extra input values beyond the transform length are ignored.
1288 // Conversely, if the input does not contain enough elements along any axis,
1289 // the data is padded with zeroes.
1290 //
1291 // For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
1292 // where length_x is the size of the full transform along the X axis.
1293 //
1294 // The input literal may have a rank higher than the rank of the transform.
1295 // Passed-in input_index value points to the first element of the input
1296 // literal to be copied.
1297 //
1298 // Returns true if all values in the work data set are zeroes.
1299 //
1300 template <typename InputType>
  bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
1302 int64_t fft_size,
1303 const absl::Span<const int64> fft_lengths,
1304 const absl::Span<const int64> fft_strides,
1305 const absl::Span<const int64> input_lengths,
1306 const absl::Span<const int64> input_strides,
1307 absl::Span<ComplexType> data) {
1308 CHECK_GE(data.size(), fft_size);
1309
1310 const bool input_is_truncated = fft_type_ == FftType::IRFFT;
1311
1312 // Recursively visit each transform dimension to copy input values to the
1313 // working data set. The base case handles inputs along the X axis.
1314 bool input_is_zero = true;
1315 const InputType* input_data = input_literal.data<InputType>().data();
1316 auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
1317 bool within_src_bounds) {
1318 if (axis == 0) {
1319 // For IRFFT, the negative frequencies are only needed for the sweep
1320 // along the X axis, which is performed last. Leave this part of the
1321 // working set uninitialized until then.
1322 const int64_t length = fft_lengths[axis];
1323 const int64_t ub = input_is_truncated ? (length / 2) + 1 : length;
1324 for (int64_t i = 0; i < ub; i++) {
1325 ComplexType value = ComplexType(0);
1326 // Read input value only if the index is within bounds.
1327 if (within_src_bounds && i < input_lengths[axis]) {
1328 value = TypeConverter<ComplexType, InputType>::GetAs(
1329 input_data[src_index + i * input_strides[axis]]);
1330 input_is_zero &= value == ComplexType(0.0, 0.0);
1331 }
1332 data[dst_index + i * fft_strides[axis]] = value;
1333 }
1334 return true;
1335 }
1336 return false;
1337 };
1338 GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
1339 fft_rank_, 0, input_start, base_case);
1340 return input_is_zero;
1341 }
1342
1343 // Copies the result of the transform to the literal output. The sizes of the
1344 // transform and output must match.
1345 //
1346 // For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
1347 // the size of the full transform along the X axis (the most minor dimension).
1348 //
1349 // The output literal may have a rank higher than the rank of the transform.
1350 // The passed-in output_start value points to the first element of the output
1351 // literal to be filled in.
1352 //
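// Illustrative example (hypothetical sizes): for a 1D RFFT with
// fft_lengths = {8}, only (8 / 2) + 1 = 5 complex results are copied along
// the X axis; the redundant negative-frequency half of the spectrum is
// dropped.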
1353 template <typename OutputType>
1354   void CopyDataToOutput(const absl::Span<ComplexType> data,
1355 int64_t output_start,
1356 const absl::Span<const int64> fft_lengths,
1357 const absl::Span<const int64> fft_strides,
1358 const absl::Span<const int64> output_lengths,
1359 const absl::Span<const int64> output_strides,
1360 Literal* output_literal) {
1361 const bool output_is_truncated = fft_type_ == FftType::RFFT;
1362
1363 // Base case for recursive copy of the results to the output. The code
1364 // avoids making a recursive call for each output element by handling axis 0
1365   // in the loop (as opposed to making "axis < 0" the base case).
1366 OutputType* output_data = output_literal->data<OutputType>().data();
1367 auto base_case = [&](int64_t axis, int64_t dst_index, int64_t src_index,
1368 bool within_src_bounds) {
1369 if (axis == 0) {
1370 // Drop negative frequencies for RFFT.
1371 const int64_t length = fft_lengths[axis];
1372 const int64_t ub = output_is_truncated ? (length / 2) + 1 : length;
1373 for (int64_t i = 0; i < output_lengths[axis]; i++) {
1374 OutputType value = OutputType(0);
1375 // Read data only if the index is within bounds.
1376 if (within_src_bounds && i < ub) {
1377 value = TypeConverter<OutputType, ComplexType>::GetAs(
1378 data[src_index + i * fft_strides[axis]]);
1379 }
1380 output_data[dst_index + i * output_strides[axis]] = value;
1381 }
1382 return true;
1383 }
1384 return false;
1385 };
1386 GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
1387 fft_rank_, output_start, 0, base_case);
1388 }
1389
1390 // Determine the type to use with the CopyDataFromInput<> template above.
1391   bool CopyDataFromInput(const Literal& input_literal, int64_t input_start,
1392 int64_t fft_size,
1393 const absl::Span<const int64> fft_lengths,
1394 const absl::Span<const int64> fft_strides,
1395 const absl::Span<const int64> input_lengths,
1396 const absl::Span<const int64> input_strides,
1397 absl::Span<ComplexType> data) {
1398 const bool input_is_float = fft_type_ == FftType::RFFT;
1399 if (input_is_float) {
1400 return CopyDataFromInput<float>(input_literal, input_start, fft_size,
1401 fft_lengths, fft_strides, input_lengths,
1402 input_strides, data);
1403 } else {
1404 return CopyDataFromInput<complex64>(input_literal, input_start, fft_size,
1405 fft_lengths, fft_strides,
1406 input_lengths, input_strides, data);
1407 }
1408 }
1409
1410 // Determine the type to use with the CopyDataToOutput<> template above.
1411   void CopyDataToOutput(const absl::Span<ComplexType> data,
1412 int64_t output_start,
1413 const absl::Span<const int64> fft_lengths,
1414 const absl::Span<const int64> fft_strides,
1415 const absl::Span<const int64> output_lengths,
1416 const absl::Span<const int64> output_strides,
1417 Literal* output_literal) {
1418 const bool output_is_float = fft_type_ == FftType::IRFFT;
1419 if (output_is_float) {
1420 CopyDataToOutput<float>(data, output_start, fft_lengths, fft_strides,
1421 output_lengths, output_strides, output_literal);
1422 } else {
1423 CopyDataToOutput<complex64>(data, output_start, fft_lengths, fft_strides,
1424 output_lengths, output_strides,
1425 output_literal);
1426 }
1427 }
1428
1429   Status CheckParameters(const Shape& input_shape, const Shape& output_shape) {
1430 // Check FFT parameters.
1431 if (fft_rank_ <= 0) {
1432 return InvalidArgument("Zero or negative FFT rank.");
1433 }
1434 if (*absl::c_min_element(fft_lengths_) < 0) {
1435 return InvalidArgument("Negative FFT length.");
1436 }
1437
1438 // Check input-related values.
1439 TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
1440 if (!input_shape.IsArray()) {
1441 return Unimplemented("Only array input shapes are supported.");
1442 }
1443 auto input_elt_type = input_shape.element_type();
1444 if (fft_type_ == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
1445 return InvalidArgument("Invalid input type: %d, must be %d (float).",
1446 input_elt_type, PrimitiveType::F32);
1447 }
1448 if (fft_type_ != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
1449 return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
1450 input_elt_type, PrimitiveType::C64);
1451 }
1452 const int64_t input_rank = input_shape.rank();
1453 if (input_rank < fft_rank_) {
1454 return InvalidArgument("Input shape rank is smaller than FFT rank.");
1455 }
1456
1457 // Check output-related values.
1458 TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
1459 if (!output_shape.IsArray()) {
1460 return Unimplemented("Only array output shapes are supported.");
1461 }
1462 auto output_elt_type = output_shape.element_type();
1463 if (fft_type_ == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
1464 return InvalidArgument("Invalid output type: %d, must be %d (float).",
1465 output_elt_type, PrimitiveType::F32);
1466 }
1467 if (fft_type_ != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
1468 return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
1469 output_elt_type, PrimitiveType::C64);
1470 }
1471 const int64_t output_rank = output_shape.rank();
1472 if (output_rank < fft_rank_) {
1473 return InvalidArgument("Output shape rank is smaller than FFT rank.");
1474 }
1475
1476 // Consistency of input and output parameters.
1477 if (input_rank != output_rank) {
1478 return InvalidArgument(
1479 "Ranks of input shape and output shape do not match.");
1480 }
1481 for (int64_t dim = 0; dim < input_rank - fft_rank_; dim++) {
1482 if (ShapeUtil::GetDimension(input_shape, dim) !=
1483 ShapeUtil::GetDimension(output_shape, dim)) {
1484 return InvalidArgument(
1485 "Higher dimension lengths of input shape and output shape do not "
1486 "match.");
1487 }
1488 }
1489
1490 return Status::OK();
1491 }
1492
1493 private:
1494 const FftType fft_type_;
1495 const int64 fft_rank_;
1496 std::vector<int64> fft_lengths_;
1497 };
1498
1499 } // namespace
1500
1501 Status HloEvaluator::HandleFft(HloInstruction* fft) {
1502 const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
1503 Literal output_literal = Literal::CreateFromShape(fft->shape());
1504
1505 FftTransform<complex128> transform(fft);
1506 TF_RETURN_IF_ERROR(transform.ComputeFft(fft, input_literal, &output_literal));
1507 evaluated_[fft] = std::move(output_literal);
1508
1509 return Status::OK();
1510 }
1511
1512 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
1513 // dimensions while keeping the rest of the output dimensions clamped to 0.
1514 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
1515 const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
1516 int64_t output_rank = output_shape.dimensions_size();
1517 std::vector<int64> index_base(output_rank, 0);
1518 std::vector<int64> index_count;
1519 index_count.reserve(output_rank);
1520 for (int64_t i = 0; i < output_rank; i++) {
1521 bool is_output_batch_dim =
1522 !absl::c_binary_search(dim_numbers.offset_dims(), i);
1523 index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
1524 }
1525
1526 return {std::move(index_base), std::move(index_count),
1527 std::vector<int64>(output_rank, 1)};
1528 }
1529
1530 // Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
1531 // dimensions while keeping the rest of the output dimensions clamped to 0.
1532 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
1533 int64_t output_rank, absl::Span<const int64> slice_sizes,
1534 const GatherDimensionNumbers& dim_numbers) {
1535 std::vector<int64> index_base(output_rank, 0);
1536 std::vector<int64> index_count(output_rank, 1);
1537 int64_t slice_sizes_idx = 0;
1538 for (int64_t i = 0; i < output_rank; i++) {
1539 bool is_output_window_dim =
1540 absl::c_binary_search(dim_numbers.offset_dims(), i);
1541 if (is_output_window_dim) {
1542 while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
1543 slice_sizes_idx)) {
1544 slice_sizes_idx++;
1545 }
1546 index_count[i] = slice_sizes[slice_sizes_idx++];
1547 }
1548 }
1549
1550 return {std::move(index_base), std::move(index_count),
1551 std::vector<int64>(output_rank, 1)};
1552 }
1553
1554 // This functor computes the contribution of start_indices to an input index
1555 // corresponding to an output index. That is, given an output index I, it picks
1556 // out the batch indices in I and uses them to look up a starting index, G, from
1557 // the start indices tensor, and expands G into the input space according to
1558 // start_index_map.
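// Illustrative example (hypothetical gather): for an f32[3,3] operand,
// s64[2,2] start_indices with index_vector_dim = 1, offset_dims = {1},
// collapsed_slice_dims = {0}, start_index_map = {0, 1}, and slice_sizes =
// {1, 3}, the output is f32[2,3]. For an output index I = (b, j), the batch
// index is (b); this functor reads G = start_indices[b, 0..1] and, following
// start_index_map, uses G as the starting (row, column) position in the
// operand.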
1559 class OutputBatchIndexToInputIndex {
1560 public:
1561 // The constructor does some setup work that is amortized across all
1562 // iterations.
1563   explicit OutputBatchIndexToInputIndex(
1564 const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
1565 const Shape& output_shape, const Literal* start_indices)
1566 : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
1567 for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
1568 output_dim_is_batch_dims_.push_back(
1569 !absl::c_binary_search(dim_numbers_.offset_dims(), i));
1570 }
1571
1572 for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
1573 int64_t index_of_input_dim_in_index_vector =
1574 std::distance(dim_numbers_.start_index_map().begin(),
1575 absl::c_find(dim_numbers_.start_index_map(), i));
1576 if (index_of_input_dim_in_index_vector ==
1577 dim_numbers_.start_index_map_size()) {
1578 input_dim_value_to_index_vector_.push_back(-1);
1579 } else {
1580 input_dim_value_to_index_vector_.push_back(
1581 index_of_input_dim_in_index_vector);
1582 }
1583 }
1584
1585 index_vector_index_.resize(start_indices_.shape().dimensions_size());
1586 input_index_.resize(input_shape.dimensions_size());
1587 int64_t index_vector_size =
1588 start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
1589 index_vector_.resize(index_vector_size);
1590 }
1591
1592 // Returns the contribution of start_indices to the input index corresponding
1593 // to output_index. See gather_inner_loop_body.
1594 //
1595 // This is conceptually a stateless transformation from output_index to the
1596 // gather input index, but:
1597 //
1598 // - Instead of allocating memory to represent the gather input index on
1599 // every invocation we reuse the same storage for the result
1600 // (input_index_), mutating it in place.
1601 // - Instead of allocating buffers for temporary values like
1602 // index_vector_index_ and index_vector on every invocation, we reuse the
1603 // same storage for all invocations.
1604 //
1605 // This returns a Span into memory owned by the class.
1606   StatusOr<absl::Span<const int64>> operator()(
1607 absl::Span<const int64> output_index) {
1608 PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
1609 TF_RETURN_IF_ERROR(FetchIndexVector());
1610 PropagateIndexVectorToInputIndex();
1611 return absl::Span<const int64>(input_index_);
1612 }
1613
1614 private:
1615 // Propagates the batch dimensions from the output index into
1616 // index_vector_index_ by mutating index_vector_index_ in place. Does not
1617 // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
1618 // we iterate over in FetchIndexVector.
1619   void PropagateOutputIndexGatherDimsToIndexVectorIndex(
1620 absl::Span<const int64> output_index) {
1621 int64_t index_vector_index_i = 0;
1622 for (int64_t i = 0, e = output_index.size(); i < e; i++) {
1623 if (!output_dim_is_batch_dims_[i]) {
1624 continue;
1625 }
1626
1627 if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
1628 index_vector_index_i++;
1629 }
1630
1631 index_vector_index_[index_vector_index_i++] = output_index[i];
1632 }
1633 }
1634
1635 // Populates index_vector_ by iterating over start_indices_ according to
1636 // index_vector_index_.
1637   Status FetchIndexVector() {
1638 int64_t index_vector_dim = dim_numbers_.index_vector_dim();
1639 for (int64_t i = 0, e = index_vector_.size(); i < e; i++) {
1640 index_vector_index_[index_vector_dim] = i;
1641 auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
1642 TF_RET_CHECK(start_index.has_value());
1643 index_vector_[i] = *start_index;
1644 }
1645 return Status::OK();
1646 }
1647
1648 // Populates input_index_.
1649   void PropagateIndexVectorToInputIndex() {
1650 for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
1651 if (input_dim_value_to_index_vector_[i] != -1) {
1652 input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
1653 }
1654
1655 // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
1656 // remains 0, as set by the constructor.
1657 }
1658 }
1659
1660 // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
1661 // the input index from the index vector. See
1662 // PropagateIndexVectorToInputIndex.
1663 std::vector<int64> input_dim_value_to_index_vector_;
1664
1665   // output_dim_is_batch_dims_[i] is true iff output dimension i is a batch
1666   // (gather) dimension.
1667 std::vector<bool> output_dim_is_batch_dims_;
1668
1669 // The buffer into which we construct an index into start_indices_ to fetch
1670 // the index vector.
1671 std::vector<int64> index_vector_index_;
1672
1673 // The index vector fetched from start_indices_.
1674 std::vector<int64> index_vector_;
1675
1676 // The result computed by this functor. operator() returns a Span into
1677 // this vector.
1678 std::vector<int64> input_index_;
1679
1680 const GatherDimensionNumbers& dim_numbers_;
1681 const Literal& start_indices_;
1682 };
1683
1684 // This functor computes the contribution of the offset indices in an output
1685 // index to an input index. That is, given an output index I it picks out the
1686 // output offset indices in I and expands them into an index into the input shape.
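// Illustrative example (same hypothetical gather as above): with
// offset_dims = {1} and collapsed_slice_dims = {0}, an output index
// I = (b, j) contributes the offset (0, j), which the gather loop adds to the
// (clamped) starting position produced by OutputBatchIndexToInputIndex.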
1687 class OutputOffsetIndexToInputIndex {
1688 public:
1689 // The constructor does some setup work that is amortized across all
1690 // iterations.
1691   explicit OutputOffsetIndexToInputIndex(
1692 const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
1693 const Shape& output_shape) {
1694 std::vector<int64> window_index_to_output_index;
1695 int64_t output_index_count = 0;
1696 for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
1697 if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
1698 window_index_to_output_index.push_back(output_index_count++);
1699 } else {
1700 output_index_count++;
1701 }
1702 }
1703
1704 int64_t window_dim_count = 0;
1705 for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
1706 if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
1707 input_dim_value_to_output_index_.push_back(-1);
1708 } else {
1709 input_dim_value_to_output_index_.push_back(
1710 window_index_to_output_index[window_dim_count++]);
1711 }
1712 }
1713
1714 input_index_.resize(input_shape.dimensions_size());
1715 }
1716
1717 // Returns the contribution of the window indices to the input index
1718 // corresponding to output_index. See gather_inner_loop_body.
1719 //
1720 // This is conceptually a stateless transformation from output_index to the
1721 // window input index, but instead of allocating memory to represent the
1722 // gather input index on every invocation we reuse the same storage for the
1723 // result (input_index_), mutating it in place.
1724 //
1725 // This returns a Span into memory owned by the class.
1726   StatusOr<absl::Span<const int64>> operator()(
1727 absl::Span<const int64> output_index) {
1728 PropagateOutputIndexWindowDimsToInputIndex(output_index);
1729 return absl::Span<const int64>(input_index_);
1730 }
1731
1732 // Returns for a given 'input_dim' the corresponding output dimension index,
1733 // or -1 if 'input_dim' is an elided window dimension.
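  // In the hypothetical gather example above, this returns -1 for input
  // dimension 0 (collapsed) and 1 for input dimension 1.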
1734   int64 input_dim_value_to_output_index(int64_t input_dim) {
1735 return input_dim_value_to_output_index_[input_dim];
1736 }
1737
1738 private:
1739 // Propagates window dimensions from the output index to input_index_ by
1740 // mutating input_index_ in place.
1741   void PropagateOutputIndexWindowDimsToInputIndex(
1742 absl::Span<const int64> output_index) {
1743 for (int64_t i = 0, e = input_index_.size(); i < e; i++) {
1744 if (input_dim_value_to_output_index_[i] != -1) {
1745 input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
1746 }
1747
1748       // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
1749 // remains 0, as set by the constructor.
1750 }
1751 }
1752
1753   // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
1754 // the input index from the output index. See
1755 // PropagateOutputIndexWindowDimsToInputIndex.
1756 std::vector<int64> input_dim_value_to_output_index_;
1757
1758 // The result computed by this functor. operator() returns a Span into
1759 // this vector.
1760 std::vector<int64> input_index_;
1761 };
1762
1763 // Reshapes the gather indices input to have a trailing degenerate `1` dimension
1764 // if necessary. Hands over the ownership of the newly created literal (if
1765 // there is one) to `reshaped_start_indices`.
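// For example, start_indices of shape s64[10] with index_vector_dim = 1
// (equal to its rank) is reshaped to s64[10, 1], while start_indices of shape
// s64[10, 2] with index_vector_dim = 1 is returned unchanged.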
1766 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
1767 int64_t index_vector_dim, const Literal& start_indices,
1768 Literal* reshaped_start_indices) {
1769 if (start_indices.shape().dimensions_size() != index_vector_dim) {
1770 return std::cref(start_indices);
1771 }
1772
1773 std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
1774 start_indices.shape().dimensions().end());
1775 new_shape.push_back(1);
1776 TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
1777 start_indices.Reshape(new_shape));
1778 return std::cref(*reshaped_start_indices);
1779 }
1780
1781 Status HloEvaluator::HandleGather(HloInstruction* gather) {
1782 Literal result = Literal::CreateFromShape(gather->shape());
1783 const Shape& shape = gather->shape();
1784 const GatherDimensionNumbers& dim_numbers =
1785 gather->gather_dimension_numbers();
1786 const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
1787 Literal reshaped_start_indices;
1788 TF_ASSIGN_OR_RETURN(
1789 const Literal& start_indices,
1790 ReshapedGatherIndices(dim_numbers.index_vector_dim(),
1791 GetEvaluatedLiteralFor(gather->operand(1)),
1792 &reshaped_start_indices));
1793
1794 // We iterate over the gather dimensions in the output shape in an outer loop
1795 // nest, and iterate over the window dimensions in the output shape in an
1796 // inner loop nest.
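  // For instance (hypothetical shapes), for an f32[2,3] output with
  // offset_dims = {1}, the outer loop visits output indices (0, 0) and (1, 0)
  // (batch dimensions swept, offset dimensions clamped to 0), and the inner
  // loop sweeps the offset dimension, adding (0, j) for j = 0..2 to each.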
1797
1798 ShapeUtil::IndexIterationSpace start_indices_iteration_space =
1799 IterationSpaceForOutputBatchIndices(shape, dim_numbers);
1800 ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
1801 IterationSpaceForOutputOffsetIndices(
1802 shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
1803
1804 // Scratch buffers that hold an index in the output shape and the
1805 // corresponding index in the input shape.
1806 std::vector<int64> input_index(operand.shape().dimensions_size());
1807 std::vector<int64> output_index(gather->shape().dimensions_size());
1808 std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
1809
1810 OutputBatchIndexToInputIndex output_batch_index_to_input_index(
1811 &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1812 /*output_shape=*/shape, &start_indices);
1813 OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
1814 gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
1815 /*output_shape=*/shape);
1816
1817 const Shape& operand_shape = operand.shape();
1818 if (ShapeUtil::IsZeroElementArray(operand_shape)) {
1819 evaluated_[gather] = std::move(result);
1820 return Status::OK();
1821 }
1822
1823 auto gather_inner_loop_body =
1824 [&](absl::Span<const int64> output_window_index,
1825 absl::Span<const int64> input_gather_index,
1826 absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1827 TF_ASSIGN_OR_RETURN(
1828 absl::Span<const int64> input_window_index,
1829 output_offset_index_to_input_index(output_window_index));
1830 for (int i = 0, e = output_index.size(); i < e; i++) {
1831 output_index[i] = output_gather_index[i] + output_window_index[i];
1832 DCHECK_LT(output_index[i], shape.dimensions(i));
1833 }
1834 for (int i = 0, e = input_gather_index.size(); i < e; i++) {
1835 int64_t output_dim =
1836 output_offset_index_to_input_index.input_dim_value_to_output_index(i);
1837 // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
1838 // we set the iteration index to 0, so for the purpose of the following
1839 // calculations we can consider the output dimension size to be 1.
1840 int64_t output_dim_size =
1841 output_dim == -1 ? 1 : shape.dimensions(output_dim);
1842 // Clamp the gather index so that the gather region fits in the operand.
1843 // input_index_clamped[i] = clamp(input_gather_index[i], 0,
1844 // operand_shape.dimensions(i) -
1845 // output_dim_size);
1846 input_index_clamped[i] =
1847 std::min(operand_shape.dimensions(i) - output_dim_size,
1848 std::max(int64{0}, input_gather_index[i]));
1849 }
1850 for (int i = 0, e = input_index.size(); i < e; i++) {
1851 input_index[i] = input_index_clamped[i] + input_window_index[i];
1852 DCHECK_GE(input_index[i], 0);
1853 DCHECK_LT(input_index[i], operand_shape.dimensions(i));
1854 }
1855 TF_RETURN_IF_ERROR(
1856 result.CopyElementFrom(operand, input_index, output_index));
1857 return true;
1858 };
1859
1860 auto gather_outer_loop_body =
1861 [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
1862 TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
1863 output_batch_index_to_input_index(output_gather_index));
1864 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1865 shape, offset_indices_iteration_space,
1866 std::bind(gather_inner_loop_body, std::placeholders::_1,
1867 input_gather_index, output_gather_index)));
1868 return true;
1869 };
1870
1871 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
1872 shape, start_indices_iteration_space, gather_outer_loop_body));
1873 evaluated_[gather] = std::move(result);
1874 return Status::OK();
1875 }
1876
1877 Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
1878 const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
1879
1880 TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
1881 << "broadcast dimensions is of size: " << broadcast->dimensions().size()
1882 << " and rank of operand_to_broadcast is: " << operand.shape().rank();
1883   // Checks that the operand's dimensions are the same as the broadcast's
1884   // output dimensions along the dimensions to be broadcast.
1885 for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
1886 auto operand_dim_size = operand.shape().dimensions(i);
1887 auto broadcast_dim_size =
1888 broadcast->shape().dimensions(broadcast->dimensions(i));
1889 TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
1890 "Operand dimension %d is broadcast to output dimension %d, but the "
1891 "sizes of these two dims do not match (%d vs %d): %s",
1892 i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
1893 broadcast->ToString());
1894 }
1895
1896 TF_ASSIGN_OR_RETURN(
1897 evaluated_[broadcast],
1898 operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
1899
1900 return Status::OK();
1901 }
1902
1903 Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
1904 evaluated_[after_all] = LiteralUtil::CreateToken();
1905 return Status::OK();
1906 }
1907
1908 Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
1909   // AddDependency just forwards its zero-th operand.
1910 evaluated_[add_dependency] =
1911 GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
1912 return Status::OK();
1913 }
1914
1915 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
1916 const auto result_shape = get_tuple_element->shape();
1917 const int64_t index = get_tuple_element->tuple_index();
1918
1919 auto operand = get_tuple_element->operand(0);
1920 TF_ASSIGN_OR_RETURN(
1921 auto inferred_return_shape,
1922 ShapeInference::InferGetTupleElementShape(operand->shape(), index));
1923 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
1924 << "return shape set to: " << ShapeUtil::HumanString(result_shape)
1925 << " but is inferred to be: "
1926 << ShapeUtil::HumanString(inferred_return_shape);
1927
1928 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1929
1930 evaluated_[get_tuple_element] =
1931 Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
1932 return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
1933 /*dest_shape_index=*/{},
1934 /*src_shape_index=*/{index});
1935 }
1936
1937 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
1938 TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
1939 evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
1940 return Status::OK();
1941 }
1942
1943 Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
1944 if (copy_start->user_count() != 1 ||
1945 copy_start->users().at(0)->opcode() != HloOpcode::kCopyDone) {
1946 return tensorflow::errors::FailedPrecondition(
1947 "Cannot evaluate a kCopyStart that doesn't have a single kCopyDone "
1948 "user.");
1949 }
1950
1951 // The context in index {2} is undefined, but since we can't represent
1952 // undefined values using a Literal, we just use 0. This should be safe though
1953 // since we ensure that the only user of a kCopyStart is a kCopyDone which
1954 // consumes the context. Also note that MakeTuple copies its arguments, so
1955 // this is memory-safe.
1956 const Literal context_literal = LiteralUtil::CreateR0<uint32>(0);
1957 evaluated_[copy_start] = LiteralUtil::MakeTuple(
1958 {&GetEvaluatedLiteralFor(copy_start->operand(0)),
1959 &GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
1960 return Status::OK();
1961 }
1962
1963 Status HloEvaluator::HandleCopyDone(HloInstruction* copy_done) {
1964 const HloInstruction* operand = copy_done->operand(0);
1965 if (operand->opcode() != HloOpcode::kCopyStart) {
1966 return tensorflow::errors::FailedPrecondition(
1967 "Cannot evaluate a kCopyDone that doesn't have a kCopyStart as "
1968 "operand.");
1969 }
1970
1971 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
1972
1973 evaluated_[copy_done] =
1974 Literal(ShapeUtil::GetTupleElementShape(operand->shape(), /*index=*/0));
1975 TF_RETURN_IF_ERROR(evaluated_[copy_done].CopyFrom(operand_tuple_literal,
1976 /*dest_shape_index=*/{},
1977 /*src_shape_index=*/{0}));
1978 return Status::OK();
1979 }
1980
1981 Status HloEvaluator::HandleCall(HloInstruction* call) {
1982 auto* computation = call->to_apply();
1983 auto operands = call->operands();
1984
1985 std::vector<const Literal*> arg_literals;
1986 arg_literals.reserve(operands.size());
1987 for (auto operand : operands) {
1988 const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
1989 arg_literals.push_back(&arg_literal);
1990 }
1991
1992 HloEvaluator embedded_evaluator;
1993 embedded_evaluator.set_dynamic_dimension_inference(
1994 dynamic_dimension_inference_);
1995 TF_ASSIGN_OR_RETURN(Literal result,
1996 embedded_evaluator.Evaluate(*computation, arg_literals));
1997
1998 evaluated_[call] = std::move(result);
1999 return Status::OK();
2000 }
2001
2002 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
2003 HloModuleConfig config;
2004 // Attach cloned computation to an empty HLO module so the existing ones are
2005 // not modified.
2006 HloModule empty_hlo_module("EmptyModuleForFusion", config);
2007 HloCloneContext context(&empty_hlo_module);
2008 auto cloned_fused_computation =
2009 fusion->fused_instructions_computation()->Clone(
2010 /*suffix=*/"clone_with_layout", &context);
2011 for (auto* instruction : cloned_fused_computation->instructions()) {
2012 if (!LayoutUtil::HasLayout(instruction->shape())) {
2013 LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
2014 }
2015 }
2016 auto readded_computation =
2017 empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
2018
2019 auto operands = fusion->operands();
2020 std::vector<const Literal*> arg_literals;
2021 arg_literals.reserve(operands.size());
2022 for (auto operand : operands) {
2023 const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
2024 arg_literals.push_back(&arg_literal);
2025 }
2026
2027 HloEvaluator embedded_evaluator;
2028 embedded_evaluator.set_dynamic_dimension_inference(
2029 dynamic_dimension_inference_);
2030 TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
2031 *readded_computation, arg_literals));
2032
2033 evaluated_[fusion] = std::move(result);
2034 return Status::OK();
2035 }
2036
2037 Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
2038 const auto& branch_index_literal =
2039 GetEvaluatedLiteralFor(conditional->operand(0));
2040 int branch_index;
2041 if (conditional->operand(0)->shape().element_type() == PRED) {
2042 branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
2043 } else {
2044 branch_index = branch_index_literal.Get<int32>({});
2045 if (branch_index < 0 || branch_index >= conditional->branch_count()) {
2046 branch_index = conditional->branch_count() - 1;
2047 }
2048 }
2049 const auto& branch_computation_arg =
2050 GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
2051
2052 HloEvaluator embedded_evaluator;
2053 embedded_evaluator.set_dynamic_dimension_inference(
2054 dynamic_dimension_inference_);
2055 TF_ASSIGN_OR_RETURN(Literal result,
2056 embedded_evaluator.Evaluate(
2057 *conditional->branch_computation(branch_index),
2058 {&branch_computation_arg}));
2059
2060 evaluated_[conditional] = std::move(result);
2061 return Status::OK();
2062 }
2063
2064 Status HloEvaluator::HandleSelect(HloInstruction* select) {
2065 const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
2066 const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
2067 const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
2068
2069   // If the predicate is scalar, no element-wise selection is needed.
2070 if (ShapeUtil::IsScalar(pred.shape())) {
2071 if (pred.Get<bool>({})) {
2072 evaluated_[select] = on_true.Clone();
2073 } else {
2074 evaluated_[select] = on_false.Clone();
2075 }
2076 return Status::OK();
2077 }
2078
2079 return DefaultAction(select);
2080 }
2081
2082 Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
2083 const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
2084 const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
2085 const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
2086
2087 if (pred.Get<bool>({})) {
2088 evaluated_[tuple_select] = on_true.Clone();
2089 } else {
2090 evaluated_[tuple_select] = on_false.Clone();
2091 }
2092 return Status::OK();
2093 }
2094
2095 Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
2096 HloComputation* cond_comp = while_hlo->while_condition();
2097 HloComputation* body_comp = while_hlo->while_body();
2098   // Initialize the loop carried value with the input to the While instruction.
2099 auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
2100 bool keep_going = true;
2101 int64_t iteration_count = 0;
2102 HloEvaluator cond_evaluator(max_loop_iterations_);
2103 cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
2104 HloEvaluator loop_body_evaluator(max_loop_iterations_);
2105 loop_body_evaluator.set_dynamic_dimension_inference(
2106 dynamic_dimension_inference_);
2107 while (keep_going) {
2108 if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
2109 return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
2110 while_hlo->name(), max_loop_iterations_);
2111 }
2112 TF_ASSIGN_OR_RETURN(auto cond_val,
2113 cond_evaluator.Evaluate(*cond_comp, {&lcv}));
2114 keep_going = cond_val.GetFirstElement<bool>();
2115 if (keep_going) {
2116 TF_ASSIGN_OR_RETURN(auto body_val,
2117 loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
2118 VLOG(3) << "Loop iteration result: " << body_val.ToString();
2119 lcv = std::move(body_val);
2120 cond_evaluator.ResetVisitStates();
2121 loop_body_evaluator.ResetVisitStates();
2122 }
2123 }
2124 evaluated_[while_hlo] = std::move(lcv);
2125 return Status::OK();
2126 }
2127
2128 namespace {
2129 template <typename NativeT>
2130 Literal ExtractLiteralFromIndexPositions(const Literal& from,
2131 absl::Span<int64 const> indices,
2132 bool extract_as_scalar) {
2133 if (extract_as_scalar) {
2134 return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
2135 }
2136   // We use an InlinedVector here because we need to convert it to an
2137 // absl::Span later, and this would not work with std::vector<bool>.
2138 absl::InlinedVector<NativeT, 10> values;
2139 for (int64_t index : indices) {
2140 values.push_back(from.Get<NativeT>({index}));
2141 }
2142 return LiteralUtil::CreateR1<NativeT>(values);
2143 }
2144
2145 StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
2146 absl::Span<int64 const> indices,
2147 bool extract_as_scalar = false) {
2148 if (extract_as_scalar) {
2149 CHECK_EQ(indices.size(), 1);
2150 }
2151 PrimitiveType type = from.shape().element_type();
2152 switch (type) {
2153 case PRED: {
2154 return ExtractLiteralFromIndexPositions<bool>(from, indices,
2155 extract_as_scalar);
2156 }
2157 case U8: {
2158 return ExtractLiteralFromIndexPositions<uint8>(from, indices,
2159 extract_as_scalar);
2160 }
2161 case S8: {
2162 return ExtractLiteralFromIndexPositions<int8>(from, indices,
2163 extract_as_scalar);
2164 }
2165 case BF16: {
2166 return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
2167 extract_as_scalar);
2168 }
2169 case F16: {
2170 return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
2171 extract_as_scalar);
2172 }
2173 case U16: {
2174 return ExtractLiteralFromIndexPositions<uint16>(from, indices,
2175 extract_as_scalar);
2176 }
2177 case S16: {
2178 return ExtractLiteralFromIndexPositions<int16>(from, indices,
2179 extract_as_scalar);
2180 }
2181 case F32: {
2182 return ExtractLiteralFromIndexPositions<float>(from, indices,
2183 extract_as_scalar);
2184 }
2185 case U32: {
2186 return ExtractLiteralFromIndexPositions<uint32>(from, indices,
2187 extract_as_scalar);
2188 }
2189 case S32: {
2190 return ExtractLiteralFromIndexPositions<int32>(from, indices,
2191 extract_as_scalar);
2192 }
2193 case F64: {
2194 return ExtractLiteralFromIndexPositions<double>(from, indices,
2195 extract_as_scalar);
2196 }
2197 case C64: {
2198 return ExtractLiteralFromIndexPositions<std::complex<float>>(
2199 from, indices, extract_as_scalar);
2200 }
2201 case U64: {
2202 return ExtractLiteralFromIndexPositions<uint64>(from, indices,
2203 extract_as_scalar);
2204 }
2205 case S64: {
2206 return ExtractLiteralFromIndexPositions<int64>(from, indices,
2207 extract_as_scalar);
2208 }
2209 case C128: {
2210 return ExtractLiteralFromIndexPositions<std::complex<double>>(
2211 from, indices, extract_as_scalar);
2212 }
2213 default:
2214 return InvalidArgument("Unsupported type for Sort: %s",
2215 PrimitiveType_Name(type));
2216 }
2217 }
2218 } // namespace
2219
2220 Status HloEvaluator::HandleSort(HloInstruction* sort) {
2221 TF_RET_CHECK(sort->operand_count() >= 1)
2222 << "Expected at least 1 operand for sort";
2223 for (int64_t i = 1; i < sort->operand_count(); ++i) {
2224 TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
2225 sort->operand(i)->shape()))
2226 << "All Sort operands must have the same dimensions";
2227 }
2228
2229 if (VLOG_IS_ON(3)) {
2230 for (int64_t i = 0; i < sort->operand_count(); ++i) {
2231 VLOG(3) << "HandleSort operand " << i << " literal: "
2232 << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
2233 }
2234 }
2235 Shape key_shape = sort->operand(0)->shape();
2236 auto rank = key_shape.rank();
2237 std::vector<Literal> result_literals;
2238 result_literals.reserve(sort->operand_count());
2239 for (int64_t i = 0; i < sort->operand_count(); ++i) {
2240 result_literals.emplace_back(sort->operand(i)->shape());
2241 }
2242 std::vector<int64> zero_base(rank, 0);
2243 std::vector<int64> increment(rank, 1);
2244 int64_t sort_dim = sort->dimensions(0);
2245 int64_t sort_dim_elements = key_shape.dimensions(sort_dim);
2246 increment[sort_dim] = sort_dim_elements;
2247 HloEvaluator embedded_evaluator(max_loop_iterations_);
2248 // Iterate through each dimension except 'sort_dim'.
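  // For example (hypothetical shape), sorting an f32[2,5] key along
  // sort_dim = 1 sets increment = {1, 5}, so the loop below visits only the
  // starting indices (0, 0) and (1, 0), one per row to be sorted.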
2249 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2250 key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
2251 [&](absl::Span<const int64> indices) -> StatusOr<bool> {
2252 // Extract a slice from each operand literal that corresponds to
2253 // exactly the row in dimension 'sort_dim'.
2254 std::vector<int64> limit_indices(indices.begin(), indices.end());
2255 absl::c_for_each(limit_indices, [](int64& index) { ++index; });
2256 limit_indices[sort_dim] = sort_dim_elements;
2257 std::vector<Literal> literals_to_sort;
2258 literals_to_sort.reserve(sort->operand_count());
2259 for (int64_t i = 0; i < sort->operand_count(); ++i) {
2260 TF_ASSIGN_OR_RETURN(auto literal_to_sort,
2261 GetEvaluatedLiteralFor(sort->operand(i))
2262 .Slice(indices, limit_indices)
2263 .Reshape({sort_dim_elements}));
2264 literals_to_sort.push_back(std::move(literal_to_sort));
2265 }
2266 std::vector<int64> indices_to_sort(sort_dim_elements);
2267 std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
2268 Status compare_status = Status::OK();
2269 auto comparator = [sort, &compare_status, &embedded_evaluator,
2270 &literals_to_sort](int64_t a, int64_t b) {
2271 std::vector<Literal> literals;
2272 literals.reserve(2 * sort->operand_count());
2273 for (int64_t i = 0; i < sort->operand_count(); ++i) {
2274 auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
2275 /*extract_as_scalar=*/true);
2276 if (!lhs.ok()) {
2277 compare_status = lhs.status();
2278 return false;
2279 }
2280 literals.push_back(std::move(lhs.ValueOrDie()));
2281 auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
2282 /*extract_as_scalar=*/true);
2283 if (!rhs.ok()) {
2284 compare_status = rhs.status();
2285 return false;
2286 }
2287 literals.push_back(std::move(rhs.ValueOrDie()));
2288 }
2289 std::vector<const Literal*> literal_ptrs;
2290 absl::c_transform(literals, std::back_inserter(literal_ptrs),
2291 [](const Literal& literal) { return &literal; });
2292
2293 auto computed_result =
2294 embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
2295 // Clear visit states so that we can use the evaluator again
2296 // on the same computation.
2297 embedded_evaluator.ResetVisitStates();
2298 if (!computed_result.ok()) {
2299 compare_status = computed_result.status();
2300 return false;
2301 }
2302 return computed_result.ValueOrDie().Get<bool>({});
2303 };
2304 if (Cast<HloSortInstruction>(sort)->is_stable()) {
2305 std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
2306 comparator);
2307 } else {
2308 std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
2309 }
2310 if (!compare_status.ok()) {
2311 return compare_status;
2312 }
2313 std::vector<int64> slice_dimensions(rank, 1);
2314 slice_dimensions[sort_dim] = sort_dim_elements;
2315 std::vector<int64> start_indices(rank, 0);
2316 for (int64_t i = 0; i < sort->operand_count(); ++i) {
2317 TF_ASSIGN_OR_RETURN(
2318 Literal sorted_literal,
2319 ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
2320 TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
2321 sorted_literal.Reshape(slice_dimensions));
2322 TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
2323 sorted_literal_reshaped, start_indices, indices,
2324 slice_dimensions));
2325 }
2326 return true;
2327 }));
2328
2329 if (sort->operand_count() == 1) {
2330 evaluated_[sort] = std::move(result_literals[0]);
2331 } else {
2332 std::vector<const Literal*> literal_ptrs;
2333 absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
2334 [](const Literal& literal) { return &literal; });
2335
2336 Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
2337 VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
2338
2339 evaluated_[sort] = std::move(result_tuple);
2340 }
2341 return Status::OK();
2342 }
2343
2344 static bool IsScalarAdd(HloComputation* computation) {
2345 HloInstruction* instruction = computation->root_instruction();
2346 if (instruction->opcode() == HloOpcode::kAdd &&
2347 computation->num_parameters() == 2) {
2348 const HloInstruction* lhs = instruction->operand(0);
2349 const HloInstruction* rhs = instruction->operand(1);
2350 return lhs->opcode() == HloOpcode::kParameter &&
2351 ShapeUtil::IsScalar(lhs->shape()) &&
2352 rhs->opcode() == HloOpcode::kParameter &&
2353 ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
2354 }
2355 return false;
2356 }
2357
2358 // Runs a single step of the reduction inner loop: applies the user-provided
2359 // computation to the current output element (which serves as the
2360 // accumulator until the reduction is completed) and the corresponding
2361 // input element.
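// For example, for a non-tuple F32 reduction (num_args == 1), each call
// evaluates computation(accumulator, input_element), where the accumulator is
// read from results[0] at output_index and the input element from
// input_args[0] at input_index, and the scalar result is written back into
// results[0] at output_index.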
2362 static StatusOr<bool> PerformReductionStep(
2363 bool is_tuple, absl::Span<const int64> input_index,
2364 absl::Span<const int64> output_index,
2365 absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
2366 HloComputation* computation, HloEvaluator* embedded_evaluator) {
2367 int num_args = results.size();
2368
2369 absl::InlinedVector<Literal, 1> arg_values;
2370 arg_values.reserve(num_args);
2371 absl::InlinedVector<Literal, 1> accumulators;
2372 accumulators.reserve(num_args);
2373 for (int64_t i = 0; i < num_args; ++i) {
2374 arg_values.emplace_back(
2375 ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
2376 accumulators.emplace_back(
2377 ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
2378
2379 TF_RETURN_IF_ERROR(
2380 arg_values[i].CopyElementFrom(*input_args[i], input_index, {}));
2381 TF_RETURN_IF_ERROR(
2382 accumulators[i].CopyElementFrom(results[i], output_index, {}));
2383 }
2384
2385 // Evaluate computation with specified literal operands.
2386 absl::InlinedVector<Literal*, 2> embedded_operands;
2387 for (Literal& accumulator : accumulators) {
2388 embedded_operands.push_back(&accumulator);
2389 }
2390 for (Literal& local_input : arg_values) {
2391 embedded_operands.push_back(&local_input);
2392 }
2393
2394 TF_ASSIGN_OR_RETURN(
2395 Literal computed_result,
2396 embedded_evaluator->Evaluate(*computation, embedded_operands));
2397
2398 // Clear visit states so that we can use the evaluator again on the same
2399 // computation.
2400 embedded_evaluator->ResetVisitStates();
2401
2402 if (is_tuple) {
2403 std::vector<Literal> computed_results = computed_result.DecomposeTuple();
2404 for (int64_t i = 0; i < num_args; ++i) {
2405 TF_RETURN_IF_ERROR(
2406 results[i].CopyElementFrom(computed_results[i], {}, output_index));
2407 }
2408 } else {
2409 TF_RETURN_IF_ERROR(
2410 results[0].CopyElementFrom(computed_result, {}, output_index));
2411 }
2412
2413 return true;
2414 }
2415
2416 static StatusOr<bool> GenerateReduceOutputElement(
2417 bool is_tuple, absl::Span<const int64> output_index,
2418
2419 absl::Span<const Literal* const> init_values,
2420 absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
2421
2422 HloComputation* function, HloEvaluator* embedded_evaluator,
2423
2424 absl::Span<const int64> arg_dim_steps,
2425 absl::Span<const int64> arg_dim_counts,
2426 absl::Span<const int64> result_to_arg_index) {
2427 bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) &&
2428 IsScalarAdd(function) && !is_tuple;
2429
2430 const Shape& arg_shape = input_args[0]->shape();
2431 absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
2432 std::vector<int64> base(arg_dimensions.size());
2433 for (int64_t i = 0; i < output_index.size(); ++i) {
2434 base[result_to_arg_index[i]] = output_index[i];
2435 }
2436
2437 for (int64_t i = 0; i < results.size(); ++i) {
2438 TF_RETURN_IF_ERROR(
2439 results[i].CopyElementFrom(*init_values[i], {}, output_index));
2440 }
2441
2442 if (use_fast_add) {
2443 double computed_result = *init_values[0]->GetAsDouble({});
2444 auto reduction_step =
2445 [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
2446 double argument = *input_args[0]->GetAsDouble(input_index);
2447 computed_result += argument;
2448 return true;
2449 };
2450 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2451 arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step));
2452 TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result));
2453 return true;
2454 }
2455
2456   // Iterates only over the reduced shape, as counts and steps are set to zero
2457 // for all non-reduced dimensions.
2458 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2459 arg_shape, base, arg_dim_counts, arg_dim_steps,
2460 [&](absl::Span<const int64> input_index) {
2461 return PerformReductionStep(is_tuple, input_index, output_index,
2462 input_args, results, function,
2463 embedded_evaluator);
2464 }));
2465 return true;
2466 }
2467
2468 Status HloEvaluator::HandleReduce(HloInstruction* instr) {
2469 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
2470 int64_t num_args = reduce->inputs().size();
2471 absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
2472 HloComputation* function = reduce->to_apply();
2473
2474 absl::InlinedVector<const Shape*, 1> operand_shapes;
2475 for (const HloInstruction* operand : reduce->operands()) {
2476 operand_shapes.push_back(&operand->shape());
2477 }
2478 TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
2479 ShapeInference::InferReduceShape(
2480 operand_shapes, dimensions_to_reduce,
2481 /*to_apply=*/function->ComputeProgramShape()));
2482 TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(),
2483 inferred_return_shape))
2484 << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
2485 << " but is inferred to be: "
2486 << ShapeUtil::HumanString(inferred_return_shape);
2487
2488 absl::InlinedVector<const Literal*, 1> input_args(num_args);
2489 absl::InlinedVector<const Literal*, 1> init_values(num_args);
2490 for (int64_t i = 0; i < num_args; ++i) {
2491 input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]);
2492 VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString();
2493 init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]);
2494 VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString();
2495 TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape()));
2496 }
2497
2498 // All args and results have the same dimensions, so pick an arbitrary one.
2499 const Shape& arg_shape = input_args[0]->shape();
2500 const Shape& out_shape = inferred_return_shape;
2501 bool is_tuple = out_shape.IsTuple();
2502 const Shape& output_shape = inferred_return_shape.IsTuple()
2503 ? inferred_return_shape.tuple_shapes(0)
2504 : inferred_return_shape;
2505
2506 absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
2507
2508 // All increments are set to 0.
2509 std::vector<int64> arg_dim_steps(arg_dimensions.size());
2510
2511 // All counts are set to 0.
2512 std::vector<int64> arg_dim_counts(arg_dimensions.size());
2513
2514 // Set steps and counts for reduced dimensions.
2515 // This avoids iterating over non-reduced dimensions, as their step
2516   // and count are set to zero.
2517 for (const int64_t dim : dimensions_to_reduce) {
2518 arg_dim_steps[dim] = 1;
2519 arg_dim_counts[dim] = arg_dimensions[dim];
2520 }
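  // For example (hypothetical shapes), reducing an f32[2,3,4] argument over
  // dimension {1} yields arg_dim_steps = {0, 1, 0} and arg_dim_counts =
  // {0, 3, 0}; result_to_arg_index computed below becomes {0, 2}.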
2521
2522 // Map each dimension in the result to a dimension in arg that isn't
2523 // being reduced.
2524 std::vector<int64> result_to_arg_index;
2525 for (int64_t i = 0; i < arg_dimensions.size(); ++i) {
2526 if (arg_dim_steps[i] == 0) {
2527 result_to_arg_index.push_back(i);
2528 }
2529 }
2530
2531 HloEvaluator embedded_evaluator(max_loop_iterations_);
2532 absl::InlinedVector<Literal, 1> results(num_args);
2533 for (int64_t i = 0; i < num_args; ++i) {
2534 results[i] = Literal(is_tuple ? out_shape.tuple_shapes(i) : out_shape);
2535 }
2536
2537 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
2538 output_shape, [&](absl::Span<const int64> output_index) {
2539 return GenerateReduceOutputElement(
2540 is_tuple, output_index, init_values, input_args,
2541 absl::Span<Literal>(results), function, &embedded_evaluator,
2542 arg_dim_steps, arg_dim_counts, result_to_arg_index);
2543 }));
2544
2545 if (is_tuple) {
2546 Literal tuple_result(inferred_return_shape);
2547 for (int64_t i = 0; i < num_args; ++i) {
2548 TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
2549 }
2550 evaluated_[reduce] = std::move(tuple_result);
2551 } else {
2552 CHECK_EQ(results.size(), 1);
2553 evaluated_[reduce] = std::move(results[0]);
2554 }
2555 if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) {
2556 TF_ASSIGN_OR_RETURN(evaluated_[reduce],
2557 evaluated_[reduce].ConvertToShape(reduce->shape()));
2558 }
2559 return Status::OK();
2560 }
2561
2562 Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) {
2563   // Here we delegate the handling to the typed visitor class, instantiated
2564   // with the type of the first input of ReduceWindow. The support for the
2565   // variadic case inside the typed visitor does not use the template
2566   // parameter, so it does not matter which type is used to instantiate it
2567   // here. We do not move the ReduceWindow implementation from the typed
2568   // visitor to here because it needs to reuse the IterateThroughWindow
2569   // method, which is defined and only available inside the typed
2570   // visitor.
2571 if (hlo->shape().IsTuple()) {
2572 return hlo->Visit(
2573 typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get());
2574 } else {
2575 return DefaultAction(hlo);
2576 }
2577 }
2578
2579 Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
2580 if (!custom_call_handler_) {
2581 // No handler is registered; this means custom-calls are not allowed.
2582 return DefaultAction(custom_call);
2583 }
2584
2585 // Evaluate input operands so the handler has access to the operand data.
2586 std::vector<const Literal*> operands;
2587 operands.reserve(custom_call->operand_count());
2588 for (const HloInstruction* operand : custom_call->operands()) {
2589 operands.push_back(&GetEvaluatedLiteralFor(operand));
2590 }
2591
2592   // Synchronously invoke the handler to populate the instruction output literal.
2593 TF_ASSIGN_OR_RETURN(
2594 auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));
2595
2596 evaluated_[custom_call] = std::move(output);
2597 return Status::OK();
2598 }
2599
2600 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
2601 VLOG(2) << "About to visit HLO: " << hlo->ToString();
2602 return ShapeUtil::ValidateShape(hlo->shape());
2603 }
2604
2605 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
2606 VLOG(2) << "Finished visiting " << hlo->ToString()
2607 << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
2608   // For convenience, the literal may have been produced with a different
2609   // layout; relayout it as indicated by the HLO instruction.
2610 if (!Layout::Equal().MinorToMajorOnly()(
2611 GetEvaluatedLiteralFor(hlo).shape().layout(),
2612 hlo->shape().layout())) {
2613 evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
2614 }
2615 return Status::OK();
2616 }
2617
2618 namespace {
2619 template <typename T>
2620 std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
2621 const Array2D<T>& lhs, const Array2D<T>& rhs,
2622 const std::function<void(const void* run_options_ptr, T* out, T* lhs,
2623 T* rhs, int64_t m, int64_t n, int64_t k,
2624 int32_t transpose_lhs, int32_t transpose_rhs)>&
2625 impl_fn) {
2626 CHECK_EQ(lhs.width(), rhs.height());
2627 int m = lhs.height();
2628 int n = rhs.width();
2629 int k = lhs.width();
2630 auto result = absl::make_unique<Array2D<T>>(m, n);
2631 // Because Eigen is a header-oriented library, make sure that the Eigen code
2632 // is the same as the code used by the CPU backend (otherwise the linker will
2633 // randomly pick *some* definition).
2634 impl_fn(
2635 /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
2636 k,
2637 /*transpose_lhs=*/0,
2638 /*transpose_rhs=*/0);
2639 return result;
2640 }
2641 } // namespace
2642
2643 std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
2644 const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
2645 return MatmulArray2DImpl<Eigen::half>(
2646 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
2647 }
2648
2649 std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
2650 const Array2D<float>& lhs, const Array2D<float>& rhs) {
2651 return MatmulArray2DImpl<float>(
2652 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
2653 }
2654
2655 std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
2656 const Array2D<double>& lhs, const Array2D<double>& rhs) {
2657 return MatmulArray2DImpl<double>(
2658 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
2659 }
2660
2661 std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
2662 const Array2D<std::complex<float>>& lhs,
2663 const Array2D<std::complex<float>>& rhs) {
2664 return MatmulArray2DImpl<std::complex<float>>(
2665 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
2666 }
2667
2668 std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
2669 const Array2D<std::complex<double>>& lhs,
2670 const Array2D<std::complex<double>>& rhs) {
2671 return MatmulArray2DImpl<std::complex<double>>(
2672 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
2673 }
2674
2675 std::unique_ptr<Array2D<int32>> HloEvaluator::MatmulArray2D(
2676 const Array2D<int32>& lhs, const Array2D<int32>& rhs) {
2677 return MatmulArray2DImpl<int32>(
2678 lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulS32);
2679 }
2680
2681 } // namespace xla
2682