/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

template <typename T>
struct is_complex_t : public std::false_type {};

template <>
struct is_complex_t<complex64> : public std::true_type {};

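// Usage sketch (with hypothetical literals): an elementwise less-than compare
// of two F32 literals into a PRED-shaped result would look like:
//   Compare<float>(pred_shape, HloOpcode::kLt, lhs_literal, rhs_literal);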
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
                                           const Literal& lhs_literal,
                                           const Literal& rhs_literal) {
  std::function<bool(OperandT, OperandT)> compare_op;
  switch (opcode) {
    case HloOpcode::kEq:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case HloOpcode::kNe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    case HloOpcode::kGe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el >= rhs_el;
      };
      break;
    case HloOpcode::kGt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el > rhs_el;
      };
      break;
    case HloOpcode::kLe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el <= rhs_el;
      };
      break;
    case HloOpcode::kLt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el < rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
                 << HloOpcodeString(opcode);
  }

  auto result = Literal::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(result->Populate<bool>(
      [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return compare_op(lhs_literal.Get<OperandT>(multi_index),
                          rhs_literal.Get<OperandT>(multi_index));
      }));

  return std::move(result);
}

template <>
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
    const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
    const Literal& rhs_literal) {
  std::function<bool(complex64, complex64)> compare_op;
  switch (opcode) {
    case HloOpcode::kEq:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case HloOpcode::kNe:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
                 << HloOpcodeString(opcode);
  }

  auto result = Literal::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(result->Populate<bool>(
      [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return compare_op(lhs_literal.Get<complex64>(multi_index),
                          rhs_literal.Get<complex64>(multi_index));
      }));

  return std::move(result);
}

template <typename ReturnT, typename NativeT>
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
    HloInstruction* instruction,
    const std::function<ReturnT(NativeT)>& unary_op,
    const Literal& operand_literal) {
  const auto shape = instruction->shape();
  const auto* operand = instruction->operand(0);

  // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
  // removed.
  if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
    return Unimplemented(
        "Implicit broadcasting is currently unsupported in HLO evaluator; "
        "shape mismatch: %s vs %s",
        ShapeUtil::HumanString(shape).c_str(),
        ShapeUtil::HumanString(operand->shape()).c_str());
  }

  auto result = Literal::CreateFromShape(shape);

  TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
      [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return unary_op(operand_literal.Get<NativeT>(multi_index));
      }));
  return std::move(result);
}

// For one particular placement of a window in a base shape (the placement is
// represented as `window_count_index`), iterates inside the window. Translates
// each window index into a base index; if the base index is within bounds,
// calls `f` with the base index.
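// For example (sketch), with stride 2, low padding 1, window_count_index {3},
// and window_index {0}, the base index is 3 * 2 + 0 - 1 = 5.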
void IterateThroughWindow(
    const Shape& window_shape, const Window& window, const Shape& base_shape,
    const tensorflow::gtl::ArraySlice<int64>& window_count_index,
    const std::function<void(const std::vector<int64>&)>& f) {
  const int64 rank = ShapeUtil::Rank(base_shape);
  DimensionVector window_index(rank);
  std::fill(window_index.begin(), window_index.end(), 0);
  do {
    std::vector<int64> base_index(rank);
    bool out_of_bound = false;
    for (int64 i = 0; i < rank; ++i) {
      base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
                      window_index[i] - window.dimensions(i).padding_low();
      if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
        out_of_bound = true;
        break;
      }
    }
    if (!out_of_bound) {
      f(base_index);
    }
  } while (IndexUtil::BumpIndices(window_shape, &window_index));
}

}  // namespace

template <typename ReturnT, typename ElementwiseT>
class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit TypedVisitor(HloEvaluator* p) : parent_(p) {}

  // The following higher-order functions convert a function with ElementwiseT
  // to a function with ReturnT.
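  // For example (sketch), with ReturnT = Eigen::half and ElementwiseT = float,
  // ConvertUnaryFunction([](float x) { return std::exp(x); }) yields a
  // half -> half function that computes exp in float precision.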
  std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
      const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
    return [&unary_op](ReturnT arg) {
      return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
    };
  }
  std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
          binary_op) {
    return [&binary_op](ReturnT arg1, ReturnT arg2) {
      return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
                                            static_cast<ElementwiseT>(arg2)));
    };
  }
  std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
                                       ElementwiseT)>& ternary_op) {
    return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
      return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
                                             static_cast<ElementwiseT>(arg2),
                                             static_cast<ElementwiseT>(arg3)));
    };
  }

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
                         HloOpcodeString(hlo_instruction->opcode()).c_str());
  }

  // TODO(b/35950897): many of the STL functions used in the handlers are not
  // overloaded for every XLA primitive type.

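  // Abs is the identity for unsigned types, which are non-negative by
  // construction.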
  template <typename NativeT,
            typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
                nullptr>
  Status HandleAbs(HloInstruction* abs) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
                          return elem_operand;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_signed<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleAbs(HloInstruction* abs) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
                        ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) {
                          return std::abs(elem_operand);
                        }));
    return Status::OK();
  }

  Status HandleAbs(HloInstruction* abs) override {
    return HandleAbs<ElementwiseT>(abs);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRound(HloInstruction* round) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[round],
        ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) {
          return std::round(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRound(HloInstruction* round) {
    return InvalidArgument("Unsupported type for Round");
  }

  Status HandleRound(HloInstruction* round) override {
    return HandleRound<ReturnT>(round);
  }

  Status HandleBroadcast(HloInstruction* broadcast) override {
    parent_->evaluated_[broadcast] =
        Literal::CreateFromShape(broadcast->shape());
    auto output = parent_->evaluated_[broadcast].get();
    const Literal& operand_to_broadcast =
        parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
    std::vector<int64> broadcast_indices(
        ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);

    TF_RET_CHECK(broadcast->dimensions().size() ==
                 ShapeUtil::Rank(operand_to_broadcast.shape()))
        << "broadcast dimensions is of size: " << broadcast->dimensions().size()
        << " and rank of operand_to_broadcast is: "
        << ShapeUtil::Rank(operand_to_broadcast.shape());
    // Checks that operand's dimensions are the same as the broadcast's
    // dimensions along the dimensions to be broadcasted.
    for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
      TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
                   operand_to_broadcast.shape().dimensions(i));
    }

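    // For example (sketch), broadcasting an operand of shape [3] with
    // dimensions {1} into shape [2, 3]: output index {i, j} reads operand
    // element {j}.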
    return output->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
            broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
          }
          return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
        });
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCeil(HloInstruction* ceil) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
                        ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) {
                          return std::ceil(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCeil(HloInstruction* ceil) {
    return InvalidArgument("Unsupported type for Ceil");
  }

  Status HandleCeil(HloInstruction* ceil) override {
    return HandleCeil<ReturnT>(ceil);
  }

  Status HandleConvert(HloInstruction* convert) override {
    const HloInstruction* operand = convert->operand(0);
    TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
                        parent_->GetEvaluatedLiteralFor(operand).Convert(
                            convert->shape().element_type()));

    if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
      parent_->evaluated_[convert] = std::move(result);
    } else {
      parent_->evaluated_[convert] =
          result->Relayout(convert->shape().layout());
    }
    return Status::OK();
  }

  Status HandleExp(HloInstruction* exp) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
                        ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) {
                          return std::exp(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleFloor(HloInstruction* floor) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[floor],
        ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) {
          return std::floor(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleFloor(HloInstruction* floor) {
    return InvalidArgument("Unsupported type for Floor");
  }

  Status HandleFloor(HloInstruction* floor) override {
    return HandleFloor<ReturnT>(floor);
  }

  Status HandleLog(HloInstruction* log) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
                        ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
                          return std::log(elem_operand);
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return ~elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return !elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return !elem_operand;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    return InvalidArgument("Unsupported type for Not");
  }

  Status HandleNot(HloInstruction* not_) override {
    return HandleNot<ElementwiseT>(not_);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_signed<NativeT>::value &&
                !std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
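    // Negate via the corresponding unsigned type so that negating the most
    // negative value (e.g. INT_MIN) wraps instead of overflowing, which would
    // be undefined behavior for signed integers.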
    using type = typename std::make_unsigned<NativeT>::type;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[negate],
        ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) {
          return NativeT(-type(elem_operand));
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                !std::is_signed<NativeT>::value ||
                std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[negate],
        ElementWiseUnaryOp(
            negate, [](ElementwiseT elem_operand) { return -elem_operand; }));
    return Status::OK();
  }

  Status HandleNegate(HloInstruction* negate) override {
    return HandleNegate<ReturnT>(negate);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleSign(HloInstruction* sign) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
                        ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
                          return (ElementwiseT(0) < elem_operand) -
                                 (elem_operand < ElementwiseT(0));
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleSign(HloInstruction* sign) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
                        ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
                          auto abs_val = std::abs(elem_operand);
                          return 0 == abs_val ? ElementwiseT(0)
                                              : elem_operand / abs_val;
                        }));
    return Status::OK();
  }

  Status HandleSign(HloInstruction* sign) override {
    return HandleSign<ReturnT>(sign);
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAtan2(HloInstruction* atan2) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
                        ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
                                                      ElementwiseT rhs_elem) {
                          return std::atan2(lhs_elem, rhs_elem);
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<!std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAtan2(HloInstruction* atan2) {
    return InvalidArgument("Unsupported type for Atan2");
  }

  Status HandleAtan2(HloInstruction* atan2) override {
    return HandleAtan2<ElementwiseT>(atan2);
  }

  Status HandleTanh(HloInstruction* tanh) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
                        ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
                          return std::tanh(elem_operand);
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_signed<NativeT>::value &&
                !std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleMultiply(HloInstruction* multiply) {
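    // As with Negate, multiply in the unsigned domain so that signed overflow
    // wraps rather than invoking undefined behavior.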
    using type = typename std::make_unsigned<NativeT>::type;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[multiply],
        ElementWiseBinaryOp(multiply,
                            [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
                              return NativeT(type(lhs_elem) * type(rhs_elem));
                            }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_unsigned<NativeT>::value ||
                              std::is_floating_point<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMultiply(HloInstruction* multiply) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[multiply],
        ElementWiseBinaryOp(multiply,
                            [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
                              return lhs_elem * rhs_elem;
                            }));
    return Status::OK();
  }

  Status HandleMultiply(HloInstruction* multiply) override {
    return HandleMultiply<ElementwiseT>(multiply);
  }

  Status HandleSubtract(HloInstruction* subtract) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[subtract],
        ElementWiseBinaryOp(subtract,
                            [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
                              return lhs_elem - rhs_elem;
                            }));
    return Status::OK();
  }

  Status HandleAdd(HloInstruction* add) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[add],
                        ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem,
                                                    ElementwiseT rhs_elem) {
                          return lhs_elem + rhs_elem;
                        }));
    return Status::OK();
  }

  Status HandleDivide(HloInstruction* divide) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
                        ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
                                                       ElementwiseT rhs_elem) {
                          return lhs_elem / rhs_elem;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMaximum(HloInstruction* maximum) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[maximum],
        ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
          return std::fmax(lhs, rhs);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMaximum(HloInstruction* maximum) {
    return InvalidArgument("Unsupported type for Maximum");
  }

  Status HandleMaximum(HloInstruction* maximum) override {
    return HandleMaximum<ElementwiseT>(maximum);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMinimum(HloInstruction* minimum) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
                        ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
                                                        ElementwiseT rhs_el) {
                          return std::fmin(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMinimum(HloInstruction* minimum) {
    return InvalidArgument("Unsupported type for Minimum");
  }

  Status HandleMinimum(HloInstruction* minimum) override {
    return HandleMinimum<ElementwiseT>(minimum);
  }

  Status HandlePower(HloInstruction* power) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[power],
                        ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
                                                      ElementwiseT rhs_el) {
                          return std::pow(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
                        ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
                                                          ElementwiseT rhs_el) {
                          return std::fmod(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    return InvalidArgument("Unsupported type for Remainder");
  }

  Status HandleRemainder(HloInstruction* remainder) override {
    return HandleRemainder<ElementwiseT>(remainder);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleAnd(HloInstruction* and_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[and_],
        ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el & rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAnd(HloInstruction* and_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[and_],
        ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el && rhs_el;
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleAnd(HloInstruction* and_) {
    return InvalidArgument("Unsupported type for And");
  }

  Status HandleAnd(HloInstruction* and_) override {
    return HandleAnd<ElementwiseT>(and_);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleOr(HloInstruction* or_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[or_],
        ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el | rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleOr(HloInstruction* or_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[or_],
        ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el || rhs_el;
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleOr(HloInstruction* or_) {
    return InvalidArgument("Unsupported type for Or");
  }

  Status HandleOr(HloInstruction* or_) override {
    return HandleOr<ElementwiseT>(or_);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftLeft(HloInstruction* shl) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shl],
        ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
          return lhs_elem << rhs_elem;
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftLeft(HloInstruction*) {
    return InvalidArgument("Unsupported type for ShiftLeft");
  }

  Status HandleShiftLeft(HloInstruction* shl) override {
    return HandleShiftLeft<ElementwiseT>(shl);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftRightArithmetic(HloInstruction* shr) {
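    // Cast the lhs to the signed type so that the right shift sign-extends
    // (an arithmetic rather than logical shift).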
    typedef typename std::make_signed<NativeT>::type SignedT;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shr],
        ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
          return static_cast<NativeT>(static_cast<SignedT>(lhs_elem) >>
                                      rhs_elem);
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftRightArithmetic(HloInstruction*) {
    return InvalidArgument("Unsupported type for ShiftRightArithmetic");
  }

  Status HandleShiftRightArithmetic(HloInstruction* shra) override {
    return HandleShiftRightArithmetic<ElementwiseT>(shra);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftRightLogical(HloInstruction* shr) {
    typedef typename std::make_unsigned<NativeT>::type UnsignedT;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shr],
        ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
          // If the shift amount is greater than or equal to the number of
          // bits, return 0 explicitly: shifting by the full bit width or more
          // is undefined behavior in C++.
          if (rhs_elem >= sizeof(UnsignedT) * CHAR_BIT) {
            return static_cast<NativeT>(0);
          }
          return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
                                      rhs_elem);
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftRightLogical(HloInstruction*) {
    return InvalidArgument("Unsupported type for ShiftRightLogical");
  }

  Status HandleShiftRightLogical(HloInstruction* shrl) override {
    return HandleShiftRightLogical<ElementwiseT>(shrl);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleClamp(HloInstruction* clamp) {
    std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
        clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
          return std::fmax(low, std::fmin(value, high));
        };
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[clamp],
        ElementwiseTernaryOp(clamp,
                             std::move(ConvertTernaryFunction(clamp_op))));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleClamp(HloInstruction*) {
    return InvalidArgument("Unsupported type for Clamp");
  }

  Status HandleClamp(HloInstruction* clamp) override {
    return HandleClamp<ElementwiseT>(clamp);
  }

  Status HandleSelect(HloInstruction* select) override {
    CHECK(!ShapeUtil::IsTuple(select->shape()));
    std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
        [](bool pred, ReturnT on_true, ReturnT on_false) {
          if (pred) {
            return on_true;
          }
          return on_false;
        };
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
                        ElementwiseTernaryOp(select, std::move(select_op)));
    return Status::OK();
  }

  Status HandleReverse(HloInstruction* reverse) override {
    const auto result_shape = reverse->shape();
    const auto reverse_dimensions = reverse->dimensions();

    auto operand = reverse->operand(0);
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferReverseShape(operand->shape(),
                                                          reverse_dimensions));

    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    auto result = Literal::CreateFromShape(result_shape);

    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> out_index) {
          std::vector<int64> from_index(out_index.begin(), out_index.end());
          for (const int64 dim : reverse_dimensions) {
            from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
          }
          return operand_literal.Get<ReturnT>(from_index);
        }));

    parent_->evaluated_[reverse] = std::move(result);
    return Status::OK();
  }

  Status HandleConvolution(HloInstruction* conv) override {
    auto lhs = conv->operand(0);
    auto rhs = conv->operand(1);
    const auto& window = conv->window();
    const Shape& result_shape = conv->shape();
    const Shape& lhs_shape = lhs->shape();
    const Shape& rhs_shape = rhs->shape();

    TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
    TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
    CHECK(ShapeUtil::IsArray(lhs_shape));
    CHECK(ShapeUtil::IsArray(rhs_shape));
    CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
    CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));

    const auto& dnums = conv->convolution_dimension_numbers();
    const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
    CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size());
    CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
    CHECK_GE(num_spatial_dims, 0);
    CHECK_EQ(window.dimensions_size(), num_spatial_dims);

    const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
    const auto rhs_rank = ShapeUtil::Rank(rhs_shape);

    CHECK_EQ(num_spatial_dims + 2, lhs_rank);
    CHECK_EQ(num_spatial_dims + 2, rhs_rank);

    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
                                                           window, dnums));
    CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    // Dimension numbers applicable for input (lhs).
    const int64 input_batch_dim = dnums.input_batch_dimension();
    const int64 input_z_dim = dnums.input_feature_dimension();
    // Dimension numbers applicable for kernel (rhs).
    const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
    const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
    // Dimension numbers applicable for output.
    const int64 output_batch_dim = dnums.output_batch_dimension();
    const int64 output_z_dim = dnums.output_feature_dimension();

    const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);

    std::vector<int64> window_dimension_sizes;
    for (auto i : dnums.kernel_spatial_dimensions()) {
      window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
    }

    const Shape& window_shape =
        ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);

    DimensionVector lhs_index(lhs_rank);
    DimensionVector rhs_index(rhs_rank);
    DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());

    auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
      ElementwiseT result_val = static_cast<ElementwiseT>(0);

      std::fill(lhs_index.begin(), lhs_index.end(), 0);
      std::fill(rhs_index.begin(), rhs_index.end(), 0);
      std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);

      lhs_index[input_batch_dim] = out_index[output_batch_dim];
      rhs_index[kernel_output_z_dim] = out_index[output_z_dim];

      // Convolve input feature with kernel.
      do {
        for (int64 iz = 0; iz < z_size; ++iz) {
          lhs_index[input_z_dim] = iz;
          rhs_index[kernel_input_z_dim] = iz;

          // Find the corresponding spatial dimension index for input (lhs).
          for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
            // Spatial dimension number for input (lhs) and output.
            const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
            const int64 output_spatial_dim =
                dnums.output_spatial_dimensions(ki);

            // Calculate the lhs (input) index without taking base dilation
            // into account.
            const auto& window_dim = window.dimensions(ki);
            const int64 undilated_index =
                out_index[output_spatial_dim] * window_dim.stride() -
                window_dim.padding_low() +
                rhs_spatial_index[ki] * window_dim.window_dilation();
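            // For example (sketch), with stride 1, low padding 1, window
            // dilation 2, out_index 0, and rhs_spatial_index 1, the undilated
            // index is 0 * 1 - 1 + 1 * 2 = 1.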
            // Skip if the lhs (input) index falls on a hole introduced by
            // base dilation. As an optimization, skip this mod if there's no
            // dilation.
            if (window_dim.base_dilation() > 1 &&
                undilated_index % window_dim.base_dilation() != 0) {
              goto cnt;
            }

            // Calculate the actual lhs (input) index after dilation. As an
            // optimization, skip this integer divide if there's no dilation.
            if (window_dim.base_dilation() > 1) {
              lhs_index[input_spatial_dim] =
                  undilated_index / window_dim.base_dilation();
            } else {
              lhs_index[input_spatial_dim] = undilated_index;
            }

            // Skip if the input index is out of bounds.
            if (!(lhs_index[input_spatial_dim] >= 0 &&
                  lhs_index[input_spatial_dim] <
                      lhs_shape.dimensions(input_spatial_dim))) {
              goto cnt;
            }

            rhs_index[dnums.kernel_spatial_dimensions(ki)] =
                window_dim.window_reversal()
                    ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
                    : rhs_spatial_index[ki];
          }

          result_val +=
              static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
              static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
        }
      cnt : {}
      } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));

      return static_cast<ReturnT>(result_val);
    };

    auto result = Literal::CreateFromShape(result_shape);
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));

    parent_->evaluated_[conv] = std::move(result);
    return Status::OK();
  }

  Status HandleDot(HloInstruction* dot) override {
    auto lhs = dot->operand(0);
    auto rhs = dot->operand(1);
    CHECK(ShapeUtil::IsArray(dot->shape()));
    CHECK(ShapeUtil::IsArray(lhs->shape()));
    CHECK(ShapeUtil::IsArray(rhs->shape()));

    const auto& dnums = dot->dot_dimension_numbers();

    const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
    const auto rhs_rank = ShapeUtil::Rank(rhs->shape());

    CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
    CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));

    // There must be exactly one contracting dimension for both lhs and rhs.
    CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1);
    CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1);
    const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
    const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
    // Contracted dimension sizes must be the same.
    CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
             rhs->shape().dimensions(rhs_contracting_dimension))
        << "lhs contracted dimension: "
        << lhs->shape().dimensions(lhs_contracting_dimension)
        << " rhs contracted dimension: "
        << rhs->shape().dimensions(rhs_contracting_dimension);
    const int64 contracted_dimension_size =
        lhs->shape().dimensions(lhs_contracting_dimension);

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    auto result = Literal::CreateFromShape(dot->shape());

    CHECK_EQ(dnums.lhs_batch_dimensions_size(),
             dnums.rhs_batch_dimensions_size());

    std::vector<int64> lhs_non_contracting_dims;
    for (int64 i = 0; i < lhs_rank; i++) {
      if (i != lhs_contracting_dimension) {
        lhs_non_contracting_dims.push_back(i);
      }
    }

    std::vector<int64> rhs_non_batch_non_contracting_dims;
    tensorflow::gtl::FlatSet<int64> batch_dims_set(
        dnums.rhs_batch_dimensions().begin(),
        dnums.rhs_batch_dimensions().end());
    for (int64 i = 0; i < rhs_rank; i++) {
      if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
        rhs_non_batch_non_contracting_dims.push_back(i);
      }
    }

    const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
    const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();

    DimensionVector lhs_index(lhs_rank);
    DimensionVector rhs_index(rhs_rank);
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> result_index) {
          ElementwiseT result_val = static_cast<ElementwiseT>(0);

          // Find the corresponding non-contracting indices for lhs and rhs.
          //
          // For `result_index`, its batch dimension, if it exists, will be at
          // the same dimension as the batch dimension of lhs and rhs. More
          // specifically:
          // - For lhs, the non-contracting dimensions, including the batch
          // dimension, have the same index as the `result_index`.
          // - For rhs, the batch dimension is set separately from other
          // non-contracting dimensions, since these other non-contracting
          // dimensions in rhs follow the non-contracting dimensions of lhs in
          // the resulting index.
          //
          // As an example, for a resulting index:
          //  result_index [result_batch, result_x, result_y]
          // the corresponding lhs and rhs indices are:
          //  lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
          //  rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
          // `result_x` is only affected by the lhs_non_contracting_dim and
          // likewise `result_y` only depends on rhs_non_contracting_dim.
          //
          // So we can look up the lhs and rhs indices by:
          //
          // lhs:
          //  batch index is the same as `result_batch`.
          //  non-contracting dimension is the same as
          //  result_index[lhs_non_contracting_dim]
          // rhs:
          //  batch index: the same as `result_batch`.
          //  non-contracting dimension index: *not* the same as
          //  result_index[rhs_non_contracting_dim], since the
          //  non-contracting dimensions of lhs are included in the
          //  result_index first. Instead, the non_contracting_dim of rhs must
          //  be calculated as follows:
          //   lhs_non_contracting_dimensions_size +
          //   (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
          //
          // Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is
          // the index offset to the result_index that only depends on
          // the non_batch and non-contracting dimensions of rhs. -1 at the
          // end translates size to index.
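          //
          // Worked example (sketch): a batch matmul with lhs [2, 3, 4],
          // rhs [2, 4, 5], result [2, 3, 5], batch dimension 0, and
          // contracting dimensions 2 (lhs) and 1 (rhs). Then
          // lhs_non_contracting_dims = {0, 1} (size 2), batch_dim_size = 1,
          // and rhs dimension 2 reads result_index[2 + (2 - 1) - 1] =
          // result_index[2], i.e. `result_y`, as expected.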
          for (auto i : lhs_non_contracting_dims) {
            lhs_index[i] = result_index[i];
          }
          for (auto i : dnums.rhs_batch_dimensions()) {
            rhs_index[i] = result_index[i];
          }
          for (auto i : rhs_non_batch_non_contracting_dims) {
            const int64 rhs_non_batch_non_contracting_dim =
                lhs_non_contracting_size + (i - batch_dim_size) - 1;
            rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
          }

          // Accumulate the resulting product along the contracted dimension.
          for (int64 i = 0; i < contracted_dimension_size; ++i) {
            lhs_index[lhs_contracting_dimension] = i;
            rhs_index[rhs_contracting_dimension] = i;

            result_val +=
                static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
                static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
          }

          return static_cast<ReturnT>(result_val);
        }));

    parent_->evaluated_[dot] = std::move(result);
    return Status::OK();
  }

  Status HandlePad(HloInstruction* pad) override {
    CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
    // The padding value must be a scalar.
    CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
    CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
             pad->padding_config().dimensions_size());

    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferPadShape(
                            /*operand_shape=*/pad->operand(0)->shape(),
                            /*padding_value_shape=*/pad->operand(1)->shape(),
                            /*padding_config=*/pad->padding_config()));
    CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    // Create a new literal of the padded shape, filled with the padding value.
    ReturnT scalar =
        parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
    auto result = Literal::CreateFromShape(pad->shape());
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
          return scalar;
        }));

    const Literal& evaluated_operand =
        parent_->GetEvaluatedLiteralFor(pad->operand(0));

    std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
                                   0);
    std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);

    // Loop through each element of the operand, assigning it to the
    // corresponding index of the resulting padded literal.
    const PaddingConfig& pad_config = pad->padding_config();

    auto func = [&](const std::vector<int64>& input_index) {
      for (auto i = 0; i < input_index.size(); ++i) {
        // Interior padding occurs logically before edge padding, so in the
        // case of negative edge padding elements are removed from the
        // interior-padded operand.
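        // For example (sketch), with edge_padding_low 2 and interior_padding
        // 1, input index 3 maps to target index 2 + 3 * (1 + 1) = 8.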
        target_index[i] =
            pad_config.dimensions(i).edge_padding_low() +
            input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);

        // Account for negative low and high padding: skip the assignment if
        // any target index is out of range.
        if (!(target_index[i] >= 0 &&
              target_index[i] < pad->shape().dimensions(i))) {
          return true;
        }
      }
      result->Set<ReturnT>(target_index,
                           evaluated_operand.Get<ReturnT>(input_index));
      return true;
    };

    std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(),
                                 0);
    std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1);

    ShapeUtil::ForEachIndex(
        evaluated_operand.shape(), zero_base,
        AsInt64Slice(evaluated_operand.shape().dimensions()), step, func);

    parent_->evaluated_[pad] = std::move(result);
    return Status::OK();
  }

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
    auto operand = dynamic_slice->operand(0);
    auto start_indices = dynamic_slice->operand(1);
    auto result_shape = dynamic_slice->shape();
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferDynamicSliceShape(
                            operand->shape(), start_indices->shape(),
                            dynamic_slice->dynamic_slice_sizes()));
    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);
    TF_RET_CHECK(
        primitive_util::IsIntegralType(start_indices->shape().element_type()));

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    const Literal& start_indices_literal =
        parent_->GetEvaluatedLiteralFor(start_indices);

    switch (start_indices->shape().element_type()) {
      case S32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<int32>(operand_literal, start_indices_literal,
                                result_shape));
      } break;
      case S64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<int64>(operand_literal, start_indices_literal,
                                result_shape));
      } break;
      case U32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<uint32>(operand_literal, start_indices_literal,
                                 result_shape));
      } break;
      case U64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<uint64>(operand_literal, start_indices_literal,
                                 result_shape));
      } break;
      default:
        LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
                      "start_indices: "
                   << PrimitiveType_Name(start_indices->shape().element_type());
    }

    return Status::OK();
  }

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override {
    auto operand = dynamic_update_slice->operand(0);
    auto update = dynamic_update_slice->operand(1);
    auto start_indices = dynamic_update_slice->operand(2);
    auto result_shape = dynamic_update_slice->shape();
    TF_ASSIGN_OR_RETURN(
        auto inferred_return_shape,
        ShapeInference::InferDynamicUpdateSliceShape(
            operand->shape(), update->shape(), start_indices->shape()));
    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);
    TF_RET_CHECK(
        primitive_util::IsIntegralType(start_indices->shape().element_type()));
    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
    const Literal& start_indices_literal =
        parent_->GetEvaluatedLiteralFor(start_indices);

    switch (start_indices->shape().element_type()) {
      case S32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<int32>(operand_literal, update_literal,
                                      start_indices_literal));
      } break;
      case S64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<int64>(operand_literal, update_literal,
                                      start_indices_literal));
      } break;
      case U32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<uint32>(operand_literal, update_literal,
                                       start_indices_literal));
      } break;
      case U64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<uint64>(operand_literal, update_literal,
                                       start_indices_literal));
      } break;
      default:
        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
                      "start_indices: "
                   << PrimitiveType_Name(start_indices->shape().element_type());
    }

    return Status::OK();
  }

  template <typename NativeT>
  StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
    auto operands = map->operands();
    HloComputation* computation = map->to_apply();

    auto result = Literal::CreateFromShape(map->shape());

    HloEvaluator embedded_evaluator;
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          std::vector<std::unique_ptr<Literal>> arg_literals;
          arg_literals.reserve(operands.size());

          // Construct scalar literal parameters to be passed to the map
          // computation.
          for (auto operand : operands) {
            const Literal& arg_literal =
                parent_->GetEvaluatedLiteralFor(operand);

            auto curr_val = arg_literal.Get<NativeT>(multi_index);
            auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);

            arg_literals.push_back(std::move(curr_val_literal));
          }

          std::unique_ptr<Literal> computed_result =
              embedded_evaluator
                  .Evaluate<std::unique_ptr<Literal>>(*computation,
                                                      arg_literals)
                  .ConsumeValueOrDie();
          // Clear visit states so that we can reuse the evaluator on the
          // same computation.
          embedded_evaluator.ResetVisitStates();

          return computed_result->Get<ReturnT>({});
        }));
    return std::move(result);
  }

  Status HandleMap(HloInstruction* map) override {
    switch (map->operand(0)->shape().element_type()) {
      case PRED: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map));
        break;
      }
      case U8: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
        break;
      }
      case U32: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
        break;
      }
      case U64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map));
        break;
      }
      case S8: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
        break;
      }
      case S32: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
        break;
      }
      case S64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
        break;
      }
      case F16: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
                            MapImpl<Eigen::half>(map));
        break;
      }
      case F32: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
        break;
      }
      case F64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map));
        break;
      }
      case C64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
        break;
      }
      default:
        LOG(FATAL) << "HandleMap: unhandled primitive type for "
                      "input operand: "
                   << PrimitiveType_Name(
                          map->operand(0)->shape().element_type());
    }

    return Status::OK();
  }

  Status HandleReduce(HloInstruction* reduce) override {
    auto arg = reduce->operand(0);
    auto init_value = reduce->operand(1);
    tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
    HloComputation* function = reduce->to_apply();
    TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
                 ShapeUtil::Rank(arg->shape()) - dimensions.size());
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferReduceShape(
                            /*arg=*/arg->shape(),
                            /*init_value=*/init_value->shape(),
                            /*dimensions_to_reduce=*/dimensions,
                            /*to_apply=*/function->ComputeProgramShape()));
    TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
    VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
    const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
    VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
    auto init_scalar = init_literal.Get<ReturnT>({});

    auto result = Literal::CreateFromShape(reduce->shape());

    const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
    std::vector<int64> arg_dim_steps(arg_dimensions.size());
    std::vector<int64> arg_dim_counts(arg_dimensions.size());
    for (const int64 dim : dimensions) {
      arg_dim_steps[dim] = 1;
      arg_dim_counts[dim] = arg_dimensions[dim];
    }

    // Create the mapping from result index to arg index.
1471 const int64 result_rank = ShapeUtil::Rank(result->shape());
1472 int64 result_dim = 0;
1473 std::vector<int64> result_to_arg_index(result_rank);
1474 for (int64 i = 0; i < arg_dimensions.size(); ++i) {
1475 if (arg_dim_steps[i] == 0) {
1476 result_to_arg_index[result_dim] = i;
1477 ++result_dim;
1478 }
1479 }
1480
1481 HloEvaluator embedded_evaluator;
1482 // For each element of the result, calculate and assign the computed value.
1483 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1484 [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
1485 ReturnT result_val = init_scalar;
1486
1487 std::vector<int64> base(arg_dimensions.size());
1488 for (int64 i = 0; i < multi_index.size(); ++i) {
1489 base[result_to_arg_index[i]] = multi_index[i];
1490 }
1491
1492 auto func = [&](const std::vector<int64>& input_index) {
1493 auto curr_val = arg_literal.Get<ReturnT>(input_index);
1494
1495 // Evaluate computation with specified literal operands.
1496 auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
1497 auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
1498 std::vector<const Literal*> args = {curr_val_literal.get(),
1499 result_val_literal.get()};
1500
1501 std::unique_ptr<Literal> computed_result =
1502 embedded_evaluator.Evaluate<const Literal*>(*function, args)
1503 .ConsumeValueOrDie();
1504 // Clear visit states so that we can use the evaluator again on
1505 // the same computation.
1506 embedded_evaluator.ResetVisitStates();
1507
1508 // Assign computed result to result_val.
1509 result_val = computed_result->Get<ReturnT>({});
1510
1511 return true;
1512 };
1513
1514 ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
1515 arg_dim_steps, func);
1516
1517 return result_val;
1518 }));
1519
1520 parent_->evaluated_[reduce] = std::move(result);
1521 return Status::OK();
1522 }
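// Illustrative example (not part of the implementation): reducing a 2x3
// operand {{1, 2, 3}, {4, 5, 6}} over dimensions={1} with init 0 and an
// add computation walks each row with ShapeUtil::ForEachIndex and yields
// the rank-1 result {6, 15}.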
1523
1524 Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
1525 auto operand = select_and_scatter->operand(0);
1526 auto source = select_and_scatter->operand(1);
1527 const Window& window = select_and_scatter->window();
1528
1529 const Literal& init_literal =
1530 parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
1531 TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
1532 auto init_scalar = init_literal.Get<ReturnT>({});
1533
1534 auto result = Literal::CreateFromShape(select_and_scatter->shape());
1535
1536 // Initialize result array with the init value.
1537 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1538 [&](tensorflow::gtl::ArraySlice<int64> output_index) {
1539 return init_scalar;
1540 }));
1541
1542 std::vector<int64> window_dimension_sizes;
1543 for (const auto& window_dimension : window.dimensions()) {
1544 window_dimension_sizes.push_back(window_dimension.size());
1545 }
1546 const Shape window_shape = ShapeUtil::MakeShape(
1547 operand->shape().element_type(), window_dimension_sizes);
1548
1549 HloComputation* select = select_and_scatter->select();
1550 HloComputation* scatter = select_and_scatter->scatter();
1551
1552 const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
1553 const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
1554
1555 int64 rank = ShapeUtil::Rank(operand_literal.shape());
1556
1557 HloEvaluator embedded_evaluator;
1558 DimensionVector source_index(rank);
1559
1560 std::fill(source_index.begin(), source_index.end(), 0);
1561 do {
1562 // For each element in `source`, we place a window in `operand`. For each
1563 // window placement, we iterate inside the window twice:
1564 //
1565 // 1. Find the selected index by applying the `select` function to all
1566 // elements. E.g., if the `select` function is GreaterEqual, the first
1567 // iteration through the window finds the biggest value and returns its
1568 // index.
1569 //
1570 // 2. Using the selected index, scatter the value from `source` to the
1571 // result. We do this by iterating through the window and comparing each
1572 // index with the selected index.
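// Illustrative example (not part of the implementation): with a rank-1
// operand {2, 6, 1, 5}, window size 2, stride 2, a GreaterEqual select,
// an Add scatter, source {10, 20}, and init 0, the first window selects
// index 1 (value 6) and the second selects index 3 (value 5), so the
// result is {0, 10, 0, 20}.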
1573 tensorflow::gtl::optional<ReturnT> selected_val;
1574 tensorflow::gtl::optional<std::vector<int64>> selected_index;
1575
1576 IterateThroughWindow(
1577 window_shape, window, operand_literal.shape(), source_index,
1578 [&](const std::vector<int64>& operand_index) {
1579 auto curr_val = operand_literal.Get<ReturnT>(operand_index);
1580 if (!selected_val) {
1581 selected_val = curr_val;
1582 selected_index = operand_index;
1583 }
1584 const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
1585 const auto selected_val_literal =
1586 Literal::CreateR0<ReturnT>(*selected_val);
1587
1588 const std::vector<const Literal*> args = {
1589 curr_val_literal.get(), selected_val_literal.get()};
1590 std::unique_ptr<Literal> computed_result =
1591 embedded_evaluator.Evaluate<const Literal*>(*select, args)
1592 .ConsumeValueOrDie();
1593 bool selected = computed_result->Get<bool>({});
1594 if (selected) {
1595 selected_val = curr_val;
1596 selected_index = operand_index;
1597 }
1598 embedded_evaluator.ResetVisitStates();
1599 });
1600
1601 IterateThroughWindow(
1602 window_shape, window, operand_literal.shape(), source_index,
1603 [&](const std::vector<int64>& operand_index) {
1604 if (std::equal(operand_index.begin(), operand_index.end(),
1605 selected_index->begin())) {
1606 auto source = source_literal.Get<ReturnT>(source_index);
1607 auto scattered = result->Get<ReturnT>(operand_index);
1608 const auto source_literal = Literal::CreateR0<ReturnT>(source);
1609 const auto scattered_literal =
1610 Literal::CreateR0<ReturnT>(scattered);
1611
1612 const std::vector<const Literal*> args = {
1613 source_literal.get(), scattered_literal.get()};
1614 std::unique_ptr<Literal> computed_result =
1615 embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
1616 .ConsumeValueOrDie();
1617 result->Set(operand_index, computed_result->Get<ReturnT>({}));
1618 // Clear visit states so that we can use the evaluator again
1619 // on the same computation.
1620 embedded_evaluator.ResetVisitStates();
1621 }
1622 });
1623 } while (IndexUtil::BumpIndices(source->shape(), &source_index));
1624
1625 parent_->evaluated_[select_and_scatter] = std::move(result);
1626 return Status::OK();
1627 }
1628
1629 Status HandleReduceWindow(HloInstruction* reduce_window) override {
1630 auto operand = reduce_window->operand(0);
1631 const Window& window = reduce_window->window();
1632 HloComputation* function = reduce_window->to_apply();
1633 TF_ASSIGN_OR_RETURN(
1634 auto inferred_return_shape,
1635 ShapeInference::InferReduceWindowShape(
1636 /*operand_shape=*/reduce_window->operand(0)->shape(),
1637 /*init_value=*/reduce_window->operand(1)->shape(), window,
1638 /*to_apply_shape=*/function->ComputeProgramShape()));
1639 TF_RET_CHECK(
1640 ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
1641 << "return shape is set to: "
1642 << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
1643 << "but is inferred to be: "
1644 << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
1645
1646 const Literal& operand_literal =
1647 parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
1648 VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
1649 const Literal& init_literal =
1650 parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
1651 VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
1652 TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
1653 auto init_scalar = init_literal.Get<ReturnT>({});
1654
1655 auto result = Literal::CreateFromShape(reduce_window->shape());
1656
1657 // Creates a Shape object from the window, for iteration below.
1658 std::vector<int64> window_dimension_sizes;
1659 for (const auto& window_dimension : window.dimensions()) {
1660 window_dimension_sizes.push_back(window_dimension.size());
1661 }
1662 const Shape window_shape = ShapeUtil::MakeShape(
1663 operand->shape().element_type(), window_dimension_sizes);
1664
1665 DimensionVector window_index(window.dimensions_size());
1666 DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
1667
1668 HloEvaluator embedded_evaluator;
1669 // For each element of the result, calculate and assign the computed value.
1670 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1671 [&](tensorflow::gtl::ArraySlice<int64> output_index) {
1672 ReturnT result_val = init_scalar;
1673
1674 std::fill(window_index.begin(), window_index.end(), 0);
1675 std::fill(operand_index.begin(), operand_index.end(), 0);
1676
1677 IterateThroughWindow(
1678 window_shape, window, operand_literal.shape(), output_index,
1679 [&](const std::vector<int64>& operand_index) {
1680 auto curr_val = operand_literal.Get<ReturnT>(operand_index);
1681
1682 // Evaluate computation with specified literal operands.
1683 const auto curr_val_literal =
1684 Literal::CreateR0<ReturnT>(curr_val);
1685 const auto result_val_literal =
1686 Literal::CreateR0<ReturnT>(result_val);
1687 const std::vector<const Literal*> args = {
1688 curr_val_literal.get(), result_val_literal.get()};
1689 std::unique_ptr<Literal> computed_result =
1690 embedded_evaluator.Evaluate<const Literal*>(*function, args)
1691 .ConsumeValueOrDie();
1692
1693 // Clear visit states so that we can use the evaluator again
1694 // on the same computation.
1695 embedded_evaluator.ResetVisitStates();
1696
1697 result_val = computed_result->Get<ReturnT>({});
1698 });
1699
1700 return result_val;
1701 }));
1702
1703 parent_->evaluated_[reduce_window] = std::move(result);
1704 return Status::OK();
1705 }
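// Illustrative example (not part of the implementation): a reduce-window
// over the rank-1 operand {1, 2, 3, 4} with window size 2, stride 1, no
// padding, init 0, and an add computation produces {3, 5, 7}: each output
// element sums the two operand elements its window covers.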
1706
1707 Status HandleSlice(HloInstruction* slice) override {
1708 auto operand = slice->operand(0);
1709 const Shape& shape = slice->shape();
1710 TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
1711 ShapeInference::InferSliceShape(
1712 operand->shape(), slice->slice_starts(),
1713 slice->slice_limits(), slice->slice_strides()));
1714 TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape))
1715 << "return shape set to: " << ShapeUtil::HumanString(shape)
1716 << " but is inferred to be: "
1717 << ShapeUtil::HumanString(inferred_return_shape);
1718
1719 const int64 rank = ShapeUtil::Rank(operand->shape());
1720 const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
1721 auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
1722 DimensionVector operand_index(rank);
1723 for (int64 i = 0; i < rank; ++i) {
1724 operand_index[i] =
1725 slice->slice_starts(i) + out_index[i] * slice->slice_strides(i);
1726 }
1727 return operand_literal.Get<ReturnT>(operand_index);
1728 };
1729
1730 auto result = Literal::CreateFromDimensions(
1731 shape.element_type(), AsInt64Slice(shape.dimensions()));
1732 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
1733 parent_->evaluated_[slice] = std::move(result);
1734 return Status::OK();
1735 }
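// Illustrative example (not part of the implementation): slicing the
// rank-1 operand {0, 1, ..., 9} with starts={2}, limits={8}, strides={2}
// reads operand indices 2, 4, 6 and produces {2, 4, 6}.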
1736
1737 template <typename NativeT, typename std::enable_if<std::is_floating_point<
1738 NativeT>::value>::type* = nullptr>
1739 Status HandleSin(HloInstruction* sin) {
1740 TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin],
1741 ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) {
1742 return std::sin(elem_operand);
1743 }));
1744 return Status::OK();
1745 }
1746
1747 template <
1748 typename NativeT,
1749 typename std::enable_if<std::is_integral<NativeT>::value ||
1750 is_complex_t<NativeT>::value>::type* = nullptr>
1751 Status HandleSin(HloInstruction* sin) {
1752 return InvalidArgument("Unsupported type for Sin");
1753 }
1754
1755 Status HandleSin(HloInstruction* sin) override {
1756 return HandleSin<ElementwiseT>(sin);
1757 }
1758
1759 template <typename NativeT, typename std::enable_if<std::is_floating_point<
1760 NativeT>::value>::type* = nullptr>
1761 Status HandleCos(HloInstruction* cos) {
1762 TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos],
1763 ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) {
1764 return std::cos(elem_operand);
1765 }));
1766 return Status::OK();
1767 }
1768
1769 template <
1770 typename NativeT,
1771 typename std::enable_if<std::is_integral<NativeT>::value ||
1772 is_complex_t<NativeT>::value>::type* = nullptr>
1773 Status HandleCos(HloInstruction* cos) {
1774 return InvalidArgument("Unsupported type for Cos");
1775 }
1776
1777 Status HandleCos(HloInstruction* cos) override {
1778 return HandleCos<ElementwiseT>(cos);
1779 }
1780
1781 template <typename NativeT, typename std::enable_if<std::is_same<
1782 float, NativeT>::value>::type* = nullptr>
1783 Status HandleReducePrecision(HloInstruction* reduce_precision) {
1784 TF_ASSIGN_OR_RETURN(
1785 parent_->evaluated_[reduce_precision],
1786 ElementWiseUnaryOp(reduce_precision, [reduce_precision](
1787 ElementwiseT elem) {
1788 uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem);
1789 const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
1790 const uint32_t exponent_bits = reduce_precision->exponent_bits();
1791
1792 // Code is based on the CPU/GPU implementation in LLVM-emitting code.
1793 //
1794 // Bits in float type:
1795 // mantissa : bits [0:22]
1796 // exponent : bits [23:30]
1797 // sign : bits [31]
1798 if (mantissa_bits < 23) {
1799 const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
1800
1801 // Compute rounding bias for round-to-nearest with ties to even.
1802 // This is equal to a base value of 0111... plus one bit if the last
1803 // remaining mantissa bit is 1.
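// Worked example (illustrative): for mantissa_bits = 10,
// last_mantissa_bit_mask = 1 << 13 = 0x2000 and base_rounding_bias =
// 0x0FFF; the bias becomes 0x1000 when the last kept mantissa bit is 1,
// so exact ties round to even.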
1804 const uint32_t base_rounding_bias =
1805 (last_mantissa_bit_mask >> 1) - 1;
1806 const uint32_t x_last_mantissa_bit =
1807 (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
1808 const uint32_t x_rounding_bias =
1809 x_last_mantissa_bit + base_rounding_bias;
1810
1811 // Add rounding bias, and mask out truncated bits. Note that the
1812 // case where adding the rounding bias overflows into the exponent
1813 // bits is correct; the non-masked mantissa bits will all be zero,
1814 // and the exponent will be incremented by one.
1815 const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
1816 value_as_int = value_as_int + x_rounding_bias;
1817 value_as_int = value_as_int & truncation_mask;
1818 }
1819 if (exponent_bits < 8) {
1820 // Masks for f32 values.
1821 const uint32_t f32_sign_bit_mask = 1u << 31;
1822 const uint32_t f32_exp_bits_mask = 0xffu << 23;
1823
1824 // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the
1825 // most-significant bit -- is equal to 1.0f for all exponent sizes.
1826 // Adding 2^(n-1)-1 to this gives us the highest non-infinite
1827 // exponent for a bit-size of n, and subtracting 2^(n-1)-1 from
1828 // this gives us the lowest exponent (corresponding to 0.0f).
1829 //
1830 // Thus, the f32 exponent corresponding to the highest non-infinite
1831 // exponent for a bit size of n is (2^7-1) + (2^(n-1)-1), and the f32
1832 // exponent corresponding to the lowest exponent for a bit size of n
1833 // is (2^7-1) - (2^(n-1)-1).
1834 //
1835 // Note that we have already checked that exponent_bits >= 1.
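// Worked example (illustrative): for exponent_bits = 5,
// reduced_exponent_bias = 15, so reduced_max_exponent = 127 + 15 = 142
// and reduced_min_exponent = 127 - 15 = 112.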
1836 const uint32_t f32_exponent_bias = (1 << 7) - 1;
1837 const uint32_t reduced_exponent_bias =
1838 (1 << (exponent_bits - 1)) - 1;
1839 const uint32_t reduced_max_exponent =
1840 f32_exponent_bias + reduced_exponent_bias;
1841 const uint32_t reduced_min_exponent =
1842 f32_exponent_bias - reduced_exponent_bias;
1843
1844 // Do we overflow or underflow?
1845 const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
1846 const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
1847 const bool x_underflows =
1848 x_exponent <= (reduced_min_exponent << 23);
1849
1850 // Compute appropriately-signed values of zero and infinity.
1851 const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
1852 const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
1853
1854 // Force to zero or infinity if overflow or underflow. (Note that
1855 // this truncates all denormal values to zero, rather than rounding
1856 // them.)
1857 value_as_int = x_overflows ? x_signed_inf : value_as_int;
1858 value_as_int = x_underflows ? x_signed_zero : value_as_int;
1859 }
1860
1861 float reduced_result = tensorflow::bit_cast<float>(value_as_int);
1862 if (std::isnan(elem)) {
1863 reduced_result = mantissa_bits > 0
1864 ? elem
1865 : std::numeric_limits<float>::infinity();
1866 }
1867 return reduced_result;
1868 }));
1869 return Status::OK();
1870 }
1871
1872 template <typename NativeT, typename std::enable_if<std::is_same<
1873 double, NativeT>::value>::type* = nullptr>
1874 Status HandleReducePrecision(HloInstruction* reduce_precision) {
1875 return InvalidArgument("Double not supported for reduce precision");
1876 }
1877
1878 template <
1879 typename NativeT,
1880 typename std::enable_if<std::is_integral<NativeT>::value ||
1881 is_complex_t<NativeT>::value>::type* = nullptr>
1882 Status HandleReducePrecision(HloInstruction* reduce_precision) {
1883 return InvalidArgument("Unsupported type for reduce precision");
1884 }
1885
1886 Status HandleReducePrecision(HloInstruction* reduce_precision) override {
1887 return HandleReducePrecision<ElementwiseT>(reduce_precision);
1888 }
1889
1890 private:
1891 template <typename IndexT>
1892 StatusOr<std::unique_ptr<Literal>> DynamicSlice(
1893 const Literal& operand_literal, const Literal& start_indices_literal,
1894 const Shape& result_shape) {
1895 auto start_indices_typed = start_indices_literal.data<IndexT>();
1896 std::vector<int64> start(start_indices_typed.begin(),
1897 start_indices_typed.end());
1898
1899 std::vector<int64> operand_indices(start.size());
1900
1901 auto result = Literal::CreateFromShape(result_shape);
1902 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1903 [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
1904 for (int64 i = 0; i < operand_indices.size(); ++i) {
1905 CHECK_GE(multi_index[i] + start[i], 0);
1906 // Mod is only used here to be consistent with the existing
1907 // backends' behavior.
1908 operand_indices[i] = (multi_index[i] + start[i]) %
1909 operand_literal.shape().dimensions(i);
1910 }
1911
1912 auto result = operand_literal.Get<ReturnT>(operand_indices);
1913 return result;
1914 }));
1915
1916 return std::move(result);
1917 }
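// Illustrative example (not part of the implementation): dynamic-slicing a
// rank-1 operand {0, 1, 2, 3, 4} with start index {3} and result size 3
// reads indices (3 + i) % 5, i.e. 3, 4, 0, and produces {3, 4, 0}.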
1918
1919 template <typename IndexT>
1920 StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
1921 const Literal& operand_literal, const Literal& update_literal,
1922 const Literal& start_indices_literal) {
1923 auto start_indices_typed = start_indices_literal.data<IndexT>();
1924 const std::vector<int64> start(start_indices_typed.begin(),
1925 start_indices_typed.end());
1926
1927 auto result = operand_literal.CloneToUnique();
1928 std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
1929
1930 auto func = [&](const std::vector<int64>& update_index) {
1931 std::transform(update_index.begin(), update_index.end(), start.begin(),
1932 result_index.begin(), std::plus<int64>());
1933
1934 result->Set<ReturnT>(result_index,
1935 update_literal.Get<ReturnT>(update_index));
1936 return true;
1937 };
1938
1939 std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
1940 std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
1941 ShapeUtil::ForEachIndex(update_literal.shape(), base,
1942 AsInt64Slice(update_literal.shape().dimensions()),
1943 step, func);
1944
1945 return std::move(result);
1946 }
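// Illustrative example (not part of the implementation): dynamic-update-
// slicing the rank-1 operand {0, 1, 2, 3, 4} with update {9, 9} at start
// index {1} overwrites indices 1 and 2 and produces {0, 9, 9, 3, 4}.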
1947
1948 StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
1949 HloInstruction* instruction,
1950 const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
1951 const Literal& operand_literal =
1952 parent_->GetEvaluatedLiteralFor(instruction->operand(0));
1953 TF_ASSIGN_OR_RETURN(
1954 auto result_literal,
1955 (ElementWiseUnaryOpImpl<ReturnT, ReturnT>(
1956 instruction, ConvertUnaryFunction(unary_op), operand_literal)));
1957
1958 return std::move(result_literal);
1959 }
1960
1961 StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
1962 HloInstruction* instruction,
1963 const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
1964 binary_op) {
1965 const auto shape = instruction->shape();
1966 const auto* lhs = instruction->operand(0);
1967 const auto* rhs = instruction->operand(1);
1968
1969 // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast
1970 // is removed.
1971 if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) &&
1972 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
1973 return Unimplemented(
1974 "Implicit broadcasting is currently unsupported in HLO evaluator "
1975 "Shape Mismatch: %s vs %s vs %s: ",
1976 ShapeUtil::HumanString(shape).c_str(),
1977 ShapeUtil::HumanString(lhs->shape()).c_str(),
1978 ShapeUtil::HumanString(rhs->shape()).c_str());
1979 }
1980
1981 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
1982 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
1983
1984 auto result = Literal::CreateFromShape(shape);
1985
1986 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
1987 [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
1988 return ConvertBinaryFunction(binary_op)(
1989 lhs_literal.Get<ReturnT>(multi_index),
1990 rhs_literal.Get<ReturnT>(multi_index));
1991 }));
1992 return std::move(result);
1993 }
1994
1995 template <typename LhsType, typename RhsType, typename EhsType>
1996 StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
1997 HloInstruction* instruction,
1998 const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
1999 const auto shape = instruction->shape();
2000 const auto* lhs = instruction->operand(0);
2001 const auto* rhs = instruction->operand(1);
2002 const auto* ehs = instruction->operand(2);
2003
2004 // TODO(b/35950897, b/27796129): add DCHECK back once implicit
2005 // broadcast is removed.
2006 if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
2007 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
2008 ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
2009 return Unimplemented(
2010 "Implicit broadcasting is currently unsupported in HLO evaluator "
2011 "Shape Mismatch: %s vs %s vs %s vs %s: ",
2012 ShapeUtil::HumanString(shape).c_str(),
2013 ShapeUtil::HumanString(lhs->shape()).c_str(),
2014 ShapeUtil::HumanString(rhs->shape()).c_str(),
2015 ShapeUtil::HumanString(ehs->shape()).c_str());
2016 }
2017
2018 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
2019 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
2020 const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
2021
2022 auto result = Literal::CreateFromShape(shape);
2023
2024 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
2025 [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
2026 return ternary_op(lhs_literal.Get<LhsType>(multi_index),
2027 rhs_literal.Get<RhsType>(multi_index),
2028 ehs_literal.Get<EhsType>(multi_index));
2029 }));
2030
2031 return std::move(result);
2032 }
2033
2034 HloEvaluator* parent_;
2035 }; // class HloEvaluator::TypedVisitor
2036
2037 HloEvaluator::HloEvaluator() {
2038 typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
2039 typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
2040 typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2041 return Unimplemented("HloEvaluator: unhandled primitive type: U16.");
2042 });
2043 typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
2044 typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
2045 typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
2046 typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2047 return Unimplemented("HloEvaluator: unhandled primitive type: S16.");
2048 });
2049 typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
2050 typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
2051 typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this);
2052 typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
2053 typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
2054 typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
2055
2056 // Most of the evaluator computations we use don't support BF16 (e.g.,
2057 // std::ceil, std::tanh). To make the evaluator work with BF16, we set all
2058 // elementwise computations to be done in F32 and do BF16<->F32 conversion
2059 // around the input and the output of the computations.
2060 typed_visitors_[BF16] = MakeUnique<TypedVisitor<bfloat16, float>>(this);
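// For example, a kTanh on a BF16 operand is evaluated as bfloat16 -> float,
// std::tanh on the float value, then float -> bfloat16 on the result.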
2061 typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2062 return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
2063 });
2064 typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
2065 return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE.");
2066 });
2067 }
2068
2069 template <typename LiteralPtr>
2070 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2071 const HloModule& module,
2072 tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) {
2073 XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
2074
2075 evaluated_.clear();
2076 arg_literals_.clear();
2077 for (const auto& literal_ptr : arg_literals) {
2078 arg_literals_.push_back(&*literal_ptr);
2079 }
2080
2081 TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
2082
2083 return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
2084 .CloneToUnique();
2085 }
2086
2087 template <typename LiteralPtr>
2088 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2089 const HloComputation& computation,
2090 tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) {
2091 XLA_VLOG_LINES(
2092 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
2093
2094 evaluated_.clear();
2095 arg_literals_.clear();
2096 for (const auto& literal_ptr : arg_literals) {
2097 arg_literals_.push_back(&*literal_ptr);
2098 }
2099
2100 TF_RETURN_IF_ERROR(computation.Accept(this));
2101 return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
2102 }
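// Example usage (a minimal sketch; assumes `computation` is an
// HloComputation& whose root adds its two R0<float> parameters):
//   HloEvaluator evaluator;
//   auto lhs = Literal::CreateR0<float>(1.0f);
//   auto rhs = Literal::CreateR0<float>(2.0f);
//   std::vector<const Literal*> args = {lhs.get(), rhs.get()};
//   StatusOr<std::unique_ptr<Literal>> result =
//       evaluator.Evaluate<const Literal*>(computation, args);
//   // On success, *result.ValueOrDie() is an R0<float> literal holding 3.0f.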
2103
2104 template <typename LiteralPtr>
2105 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2106 HloInstruction* instruction,
2107 tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) {
2108 TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
2109 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
2110
2111 evaluated_.clear();
2112 arg_literals_.clear();
2113 for (const auto& literal_ptr : arg_literals) {
2114 arg_literals_.push_back(&*literal_ptr);
2115 }
2116
2117 // Evaluate operands of Parameter type against the input literals; this
2118 // caches the evaluated literal results.
2119 for (const auto operand : instruction->operands()) {
2120 if (operand->opcode() == HloOpcode::kParameter) {
2121 const Literal* input_literal = arg_literals_[operand->parameter_number()];
2122 VLOG(2) << "Parameter operand evaluated to: "
2123 << input_literal->ToString();
2124 TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
2125
2126 evaluated_[operand] = input_literal->CloneToUnique();
2127 }
2128 }
2129
2130 TF_RETURN_IF_ERROR(Preprocess(instruction));
2131 TF_RETURN_IF_ERROR(instruction->Visit(this));
2132 TF_RETURN_IF_ERROR(Postprocess(instruction));
2133 return GetEvaluatedLiteralFor(instruction).CloneToUnique();
2134 }
2135
2136 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
2137 HloInstruction* instruction) {
2138 if (instruction->opcode() == HloOpcode::kParameter) {
2139 return tensorflow::errors::FailedPrecondition(
2140 "Cannot evaluate a parameter.");
2141 }
2142 if (!hlo_query::AllOperandsAreConstants(*instruction)) {
2143 return tensorflow::errors::FailedPrecondition(
2144 "Not all operands are constants.");
2145 }
2146 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
2147
2148 arg_literals_.clear();
2149 evaluated_.clear();
2150
2151 TF_RETURN_IF_ERROR(Preprocess(instruction));
2152 TF_RETURN_IF_ERROR(instruction->Visit(this));
2153 TF_RETURN_IF_ERROR(Postprocess(instruction));
2154 return GetEvaluatedLiteralFor(instruction).CloneToUnique();
2155 }
2156
2157 std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
2158 HloInstruction* instruction) {
2159 auto result_or = Evaluate(instruction);
2160 if (!result_or.ok()) {
2161 VLOG(1) << "TryEvaluate failed:" << result_or.status();
2162 return nullptr;
2163 }
2164
2165 return result_or.ConsumeValueOrDie();
2166 }
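// Example usage (a minimal sketch; assumes `instruction` is an
// HloInstruction* whose operands are all constants):
//   HloEvaluator evaluator;
//   std::unique_ptr<Literal> folded = evaluator.TryEvaluate(instruction);
//   if (folded != nullptr) {
//     // Use the constant-folded literal; nullptr means evaluation failed.
//   }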
2167
2168 StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
2169 const HloInstruction* instruction,
2170 const std::unordered_map<const HloInstruction*, const Literal*>&
2171 substitutions) {
2172 std::vector<std::unique_ptr<HloInstruction>> owned_operands;
2173 for (const HloInstruction* operand : instruction->operands()) {
2174 auto it = substitutions.find(operand);
2175 if (it == substitutions.end()) {
2176 owned_operands.push_back(operand->Clone());
2177 } else {
2178 owned_operands.push_back(
2179 HloInstruction::CreateConstant(it->second->CloneToUnique()));
2180 }
2181 }
2182
2183 std::vector<HloInstruction*> operands;
2184 operands.reserve(owned_operands.size());
2185 for (auto& operand : owned_operands) {
2186 operands.push_back(operand.get());
2187 }
2188
2189 std::unique_ptr<HloInstruction> cloned_instruction =
2190 instruction->CloneWithNewOperands(instruction->shape(), operands);
2191 auto result = Evaluate(cloned_instruction.get());
2192
2193 // Clean up our cloned instructions before returning.
2194 cloned_instruction->DetachFromOperands();
2195 for (auto& operand : owned_operands) {
2196 operand->DetachFromOperands();
2197 }
2198
2199 return result;
2200 }
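// Example usage (a minimal sketch; assumes `add` is an HloInstruction*
// computing param0 + constant, and `param0` is its parameter operand):
//   HloEvaluator evaluator;
//   auto value = Literal::CreateR0<float>(4.0f);
//   std::unordered_map<const HloInstruction*, const Literal*> substitutions =
//       {{param0, value.get()}};
//   auto result = evaluator.EvaluateWithSubstitutions(add, substitutions);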
2201
2202 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
2203 CHECK_LT(parameter->parameter_number(), arg_literals_.size());
2204 const Literal* input_literal = arg_literals_[parameter->parameter_number()];
2205 VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
2206 DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
2207 << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
2208 << ", but input literal shape is: "
2209 << ShapeUtil::HumanString(input_literal->shape());
2210
2211 evaluated_[parameter] = input_literal->CloneToUnique();
2212 return Status::OK();
2213 }
2214
2215 Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
2216
2217 Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
2218 TF_ASSIGN_OR_RETURN(
2219 evaluated_[reshape],
2220 GetEvaluatedLiteralFor(reshape->operand(0))
2221 .Reshape(AsInt64Slice(reshape->shape().dimensions())));
2222 return Status::OK();
2223 }
2224
2225 Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
2226 evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
2227 .Transpose(transpose->dimensions());
2228 return Status::OK();
2229 }
2230
2231 Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
2232 tensorflow::gtl::ArraySlice<HloInstruction*> operands(
2233 concatenate->operands());
2234 // The concatenate dimension of the result is the sum of the concatenate
2235 // dimensions of all operands taking part in the operation.
2236 const Shape& reference_shape = operands[0]->shape();
2237 CHECK(!ShapeUtil::IsTuple(reference_shape));
2238 const int64 rank = ShapeUtil::Rank(reference_shape);
2239 const int64 concat_dim = concatenate->dimensions()[0];
2240 CHECK_GE(concat_dim, 0);
2241 CHECK_LT(concat_dim, rank);
2242
2243 DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
2244 reference_shape.dimensions().end());
2245
2246 for (int64 i = 1; i < operands.size(); ++i) {
2247 const Shape& operand_shape = operands[i]->shape();
2248 CHECK(!ShapeUtil::IsTuple(operand_shape));
2249 // Accumulate the concat dimension from all tensors taking part in the
2250 // operation.
2251 concat_dimensions[concat_dim] +=
2252 ShapeUtil::GetDimension(operand_shape, concat_dim);
2253 }
2254
2255 auto result_literal = Literal::CreateFromDimensions(
2256 reference_shape.element_type(), concat_dimensions);
2257 DimensionVector source_indices(rank, 0);
2258 DimensionVector dest_indices(concat_dimensions.size(), 0);
2259
2260 for (auto operand : operands) {
2261 const Shape& operand_shape = operand->shape();
2262 TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
2263 GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
2264 AsInt64Slice(operand_shape.dimensions())));
2265 dest_indices[concat_dim] +=
2266 ShapeUtil::GetDimension(operand_shape, concat_dim);
2267 }
2268
2269 evaluated_[concatenate] = std::move(result_literal);
2270 return Status::OK();
2271 }
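// Illustrative example (not part of the implementation): concatenating
// operands of shapes f32[2,3] and f32[2,4] along concat_dim=1 yields shape
// f32[2,7]; dest_indices[1] starts at 0 for the first operand and advances
// to 3 before the second operand's slice is copied in.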
2272
2273 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
2274 auto operand = is_finite->operand(0);
2275 if (!ShapeUtil::ElementIsFloating(operand->shape())) {
2276 return InvalidArgument(
2277 "expected element type in shape to be float for IsFinite op, got: %s",
2278 PrimitiveType_Name(operand->shape().element_type()).c_str());
2279 }
2280
2281 switch (operand->shape().element_type()) {
2282 case F16:
2283 return Unimplemented("unhandled primitive type: F16.");
2284 case F32: {
2285 auto result_or = ElementWiseUnaryOpImpl<bool, float>(
2286 is_finite,
2287 [](float elem_operand) { return std::isfinite(elem_operand); },
2288 GetEvaluatedLiteralFor(operand));
2289 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
2290 break;
2291 }
2292 case F64: {
2293 auto result_or = ElementWiseUnaryOpImpl<bool, double>(
2294 is_finite,
2295 [](double elem_operand) { return std::isfinite(elem_operand); },
2296 GetEvaluatedLiteralFor(operand));
2297 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
2298 break;
2299 }
2300 default:
2301 LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: "
2302 << PrimitiveType_Name(operand->shape().element_type());
2303 }
2304
2305 return Status::OK();
2306 }
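// Illustrative example (not part of the implementation): IsFinite on the
// F32 operand {1.0f, NaN, inf} produces the PRED literal
// {true, false, false}.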
2307
2308 Status HloEvaluator::HandleCompare(HloInstruction* compare) {
2309 HloOpcode opcode = compare->opcode();
2310 auto lhs = compare->operand(0);
2311 auto rhs = compare->operand(1);
2312 // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
2313 // removed.
2314 if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
2315 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
2316 return Unimplemented(
2317 "Implicit broadcasting is currently unsupported in HLO evaluator "
2318 "Shape Mismatch: %s vs %s vs %s",
2319 ShapeUtil::HumanString(compare->shape()).c_str(),
2320 ShapeUtil::HumanString(lhs->shape()).c_str(),
2321 ShapeUtil::HumanString(rhs->shape()).c_str());
2322 }
2323
2324 TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
2325
2326 const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
2327 const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
2328
2329 // Note here we switch on the operand's type.
2330 switch (lhs->shape().element_type()) {
2331 case PRED: {
2332 TF_ASSIGN_OR_RETURN(
2333 evaluated_[compare],
2334 Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
2335 } break;
2336 case U8: {
2337 TF_ASSIGN_OR_RETURN(
2338 evaluated_[compare],
2339 Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
2340 } break;
2341 case U16:
2342 return Unimplemented("unhandled primitive type: U16.");
2343 case U32: {
2344 TF_ASSIGN_OR_RETURN(
2345 evaluated_[compare],
2346 Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
2347 } break;
2348 case U64: {
2349 TF_ASSIGN_OR_RETURN(
2350 evaluated_[compare],
2351 Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
2352 } break;
2353 case S8: {
2354 TF_ASSIGN_OR_RETURN(
2355 evaluated_[compare],
2356 Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
2357 } break;
2358 case S16:
2359 return Unimplemented("unhandled primitive type: S16.");
2360 case S32: {
2361 TF_ASSIGN_OR_RETURN(
2362 evaluated_[compare],
2363 Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
2364 } break;
2365 case S64: {
2366 TF_ASSIGN_OR_RETURN(
2367 evaluated_[compare],
2368 Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
2369 } break;
2370 case F16:
2371 return Unimplemented("unhandled primitive type: F16.");
2372 case F32: {
2373 TF_ASSIGN_OR_RETURN(
2374 evaluated_[compare],
2375 Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
2376 } break;
2377 case F64: {
2378 TF_ASSIGN_OR_RETURN(
2379 evaluated_[compare],
2380 Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
2381 } break;
2382 case C64: {
2383 TF_ASSIGN_OR_RETURN(evaluated_[compare],
2384 Compare<complex64>(compare->shape(), opcode,
2385 lhs_literal, rhs_literal));
2386 } break;
2387 default:
2388 LOG(FATAL) << "HandleCompare: unknown primitive type: "
2389 << PrimitiveType_Name(lhs->shape().element_type());
2390 }
2391
2392 return Status::OK();
2393 }
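// Illustrative example (not part of the implementation): a kGe compare of
// the S32 operands {1, 2} and {2, 2} produces the PRED literal
// {false, true}.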
2394
2395 Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
2396 std::vector<const Literal*> operand_literals;
2397 for (auto operand : tuple->operands()) {
2398 operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
2399 }
2400
2401 evaluated_[tuple] = Literal::MakeTuple(operand_literals);
2402 return Status::OK();
2403 }
2404
2405 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
2406 const auto result_shape = get_tuple_element->shape();
2407 const int64 index = get_tuple_element->tuple_index();
2408
2409 auto operand = get_tuple_element->operand(0);
2410 TF_ASSIGN_OR_RETURN(
2411 auto inferred_return_shape,
2412 ShapeInference::InferGetTupleElementShape(operand->shape(), index));
2413 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
2414 << "return shape set to: " << ShapeUtil::HumanString(result_shape)
2415 << " but is inferred to be: "
2416 << ShapeUtil::HumanString(inferred_return_shape);
2417
2418 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
2419
2420 evaluated_[get_tuple_element] = MakeUnique<Literal>(
2421 ShapeUtil::GetTupleElementShape(operand->shape(), index));
2422 return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
2423 /*dest_shape_index=*/{},
2424 /*src_shape_index=*/{index});
2425 }
2426
2427 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
2428 TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
2429
2430 auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
2431 evaluated_[copy] = std::move(result);
2432 return Status::OK();
2433 }
2434
2435 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
2436 VLOG(2) << "About to visit HLO: " << hlo->ToString();
2437 return Status::OK();
2438 }
2439
2440 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
2441 VLOG(2) << "Finished visiting " << hlo->ToString()
2442 << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
2443 return Status::OK();
2444 }
2445
2446 // Explicit instantiation of templatized Evaluate* methods.
2447 //
2448 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
2449 const Literal*>(const HloModule& module,
2450 tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
2451 template StatusOr<std::unique_ptr<Literal>>
2452 HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
2453 const HloModule& module,
2454 tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals);
2455
2456 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
2457 const Literal*>(const HloComputation& computation,
2458 tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
2459 template StatusOr<std::unique_ptr<Literal>>
2460 HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
2461 const HloComputation& computation,
2462 tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals);
2463
2464 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
2465 const Literal*>(HloInstruction* instruction,
2466 tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
2467 template StatusOr<std::unique_ptr<Literal>>
2468 HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
2469 HloInstruction* instruction,
2470 tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals);
2471
2472 } // namespace xla
2473