1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/client/value_inference.h"
16
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/comparison_util.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
31 #include "tensorflow/compiler/xla/service/hlo.pb.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/stream_executor/lib/statusor.h"
41
42 namespace xla {
43 namespace {
44 Literal CreatePredLiteral(bool pred, const Shape& reference_shape) {
45 if (reference_shape.IsTuple()) {
46 std::vector<Literal> sub_literals;
47 const auto& reference_shape_tuple_shapes = reference_shape.tuple_shapes();
48 sub_literals.reserve(reference_shape_tuple_shapes.size());
49 for (const Shape& shape : reference_shape_tuple_shapes) {
50 sub_literals.emplace_back(CreatePredLiteral(pred, shape));
51 }
52 return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
53 }
54 PrimitiveType element_type = reference_shape.element_type();
55 if (element_type == TOKEN) {
56 return LiteralUtil::CreateR0(pred);
57 }
58 Literal literal = LiteralUtil::CreateR0(pred);
59 Literal literal_broadcast =
60 literal.Broadcast(ShapeUtil::ChangeElementType(reference_shape, PRED), {})
61 .ValueOrDie();
62 return literal_broadcast;
63 }
64
65 Literal CreateS64Literal(int64_t value, const Shape& reference_shape) {
66 if (reference_shape.IsTuple()) {
67 std::vector<Literal> sub_literals;
68 const auto& reference_shape_tuple_shapes = reference_shape.tuple_shapes();
69 sub_literals.reserve(reference_shape_tuple_shapes.size());
70 for (const Shape& shape : reference_shape_tuple_shapes) {
71 sub_literals.emplace_back(CreateS64Literal(value, shape));
72 }
73 return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
74 }
75 PrimitiveType element_type = reference_shape.element_type();
76 if (element_type == TOKEN) {
77 return LiteralUtil::CreateToken();
78 }
79 Literal literal = LiteralUtil::CreateR0<int64_t>(value);
80 return literal
81 .Broadcast(ShapeUtil::ChangeElementType(reference_shape, S64), {})
82 .ValueOrDie();
83 }
84
85 // Create a literal with garbage data. The data inside is undefined and
86 // shouldn't be used in any meaningful computation.
87 Literal CreateGarbageLiteral(const Shape& reference_shape) {
88 if (reference_shape.IsTuple()) {
89 std::vector<Literal> sub_literals;
90 for (const Shape& shape : reference_shape.tuple_shapes()) {
91 sub_literals.emplace_back(CreateGarbageLiteral(shape));
92 }
93 return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
94 }
95 PrimitiveType element_type = reference_shape.element_type();
96 if (element_type == TOKEN) {
97 return LiteralUtil::CreateToken();
98 }
99 Literal literal = LiteralUtil::One(element_type);
100 return literal.Broadcast(reference_shape, {}).ValueOrDie();
101 }
102
103 // HloProtoEvaluator evaluates an HLO instruction proto and returns a literal.
104 // The user has to provide the operands as literals through WithOperands().
105 struct HloProtoEvaluator {
106 explicit HloProtoEvaluator(HloEvaluator& evaluator, HloInstructionProto inst)
107 : evaluator(evaluator),
108 inst(std::move(inst)),
109 module("EmptyModuleForEvaluation", HloModuleConfig()) {}
110
111 // WithComputation changes the called computation of the instruction being
112 // evaluated.
113 HloProtoEvaluator& WithComputation(
114 std::unique_ptr<HloComputation> new_computation) {
115 computation = new_computation.get();
116 computation->ClearUniqueIdInternal();
117 for (HloInstruction* inst : computation->instructions()) {
118 inst->ClearUniqueIdInternal();
119 }
120 module.AddEmbeddedComputation(std::move(new_computation));
121 return *this;
122 }
123
124 // WithPrimitiveType changes the primitive type of the instruction being
125 // evaluated.
126 HloProtoEvaluator& WithPrimitiveType(PrimitiveType new_primitive_type) {
127 primitive_type = new_primitive_type;
128 return *this;
129 }
130
131 // WithOpCode changes the opcode of the instruction being evaluated.
132 HloProtoEvaluator& WithOpCode(HloOpcode new_opcode) {
133 opcode = new_opcode;
134 return *this;
135 }
136
137 // WithOperands changes the operands of the instruction being evaluated.
138 HloProtoEvaluator& WithOperands(absl::Span<Literal> operands) {
139 this->operands = operands;
140 return *this;
141 }
142
143 // When WithSubshape is set, the result tuple will be decomposed and only the
144 // literal at the given shape index will be returned.
145 HloProtoEvaluator& WithSubshape(ShapeIndex shape_index) {
146 this->shape_index = std::move(shape_index);
147 return *this;
148 }
149
150 StatusOr<Literal> Evaluate() {
151 // Evaluate the instruction by swapping its operands with constant
152 // instructions with given literals.
153 HloComputation::Builder builder("EmptyComputation");
154 absl::flat_hash_map<int64_t, HloInstruction*> operand_map;
155 for (int64_t i = 0; i < inst.operand_ids_size(); ++i) {
156 int64_t operand_handle = inst.operand_ids(i);
157 std::unique_ptr<HloInstruction> operand =
158 HloInstruction::CreateConstant(operands[i].Clone());
159 operand_map[operand_handle] = operand.get();
160 builder.AddInstruction(std::move(operand));
161 }
162
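// Optionally override the instruction's result element type and opcode before
// rebuilding it from the proto.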
163 if (primitive_type.has_value()) {
164 *inst.mutable_shape() = ShapeUtil::ChangeElementType(
165 Shape(inst.shape()), primitive_type.value())
166 .ToProto();
167 }
168 if (opcode.has_value()) {
169 *inst.mutable_opcode() = HloOpcodeString(opcode.value());
170 }
171 absl::flat_hash_map<int64_t, HloComputation*> computation_map;
172 if (inst.called_computation_ids_size() != 0) {
173 TF_RET_CHECK(inst.called_computation_ids_size() == 1 &&
174 computation != nullptr)
175 << inst.DebugString();
176 computation_map[inst.called_computation_ids(0)] = computation;
177 }
178 TF_ASSIGN_OR_RETURN(
179 auto new_instruction,
180 HloInstruction::CreateFromProto(inst, operand_map, computation_map));
181 new_instruction->ClearUniqueIdInternal();
182 builder.AddInstruction(std::move(new_instruction));
183 auto computation = builder.Build();
184 module.AddEntryComputation(std::move(computation));
185 if (shape_index.empty()) {
186 return evaluator.Evaluate(module.entry_computation()->root_instruction());
187 } else {
188 TF_ASSIGN_OR_RETURN(
189 auto result,
190 evaluator.Evaluate(module.entry_computation()->root_instruction()));
191 return result.SubLiteral(this->shape_index);
192 }
193 }
194
195 HloEvaluator& evaluator;
196 HloInstructionProto inst;
197
198 HloModule module;
199 absl::Span<Literal> operands;
200 ShapeIndex shape_index = {};
201 HloComputation* computation = nullptr;
202 std::optional<PrimitiveType> primitive_type = std::nullopt;
203 std::optional<HloOpcode> opcode = std::nullopt;
204 };
205
206 enum PostorderDFSNodeType {
207 // This node is about figuring out the constant value.
208 kConstantValue = 0,
209 // This node is about figuring out the constant bound.
210 kConstantUpperBound,
211 kConstantLowerBound,
212 // This node is about figuring out whether a value is dynamic.
213 kValueIsDynamic,
214 // This node is about figuring out whether a bound value is dynamic. It's
215 // similar to kValueIsDynamic, but views shape bound as static values.
216 kBoundIsDynamic,
217 };
218
219 std::string PostorderDFSNodeTypeToString(PostorderDFSNodeType type) {
220 switch (type) {
221 case kConstantValue:
222 return "kConstantValue";
223 case kConstantUpperBound:
224 return "kConstantUpperBound";
225 case kConstantLowerBound:
226 return "kConstantLowerBound";
227 case kValueIsDynamic:
228 return "kValueIsDynamic";
229 case kBoundIsDynamic:
230 return "kBoundIsDynamic";
231 }
232 }
233
234 struct InferenceContext {
235 explicit InferenceContext(ShapeIndex shape_index,
236 std::vector<int64_t> caller_operand_handles)
237 : shape_index(std::move(shape_index)),
238 caller_operand_handles(std::move(caller_operand_handles)) {}
239 // `shape_index` represents the subshape that we care about in the inference.
240 // It is used to avoid materializing the whole tuple when we only care about a
241 // sub-tensor of it.
242 ShapeIndex shape_index;
243
244 // caller_operand_handles is a stack that helps argument forwarding. The top
245 // of the stack represents the tensor to be forwarded to the
246 // parameter of the innermost function. E.g.,:
247 // inner_true_computation {
248 // inner_p0 = param()
249 // ...
250 // }
251 //
252 // true_computation {
253 // p0 = param()
254 // conditional(pred, p0, inner_true_computation,
255 // ...)
256 // }
257 //
258 // main {
259 // op = ..
260 // conditional(pred, op, true_computation, ...)
261 // }
262 //
263 // In this case, when we analyze inner_true_computation, the
264 // `caller_operand_handles` will be [op, p0] -- p0 is what should be
265 // forwarded to inner_p0 and op is what should be forwarded to p0. Similarly,
266 // when we analyze true_computation, the `caller_operand_handles` will be
267 // [op].
268 std::vector<int64_t> caller_operand_handles;
269 };
270
271 // Each node in the postorder traversal tree may depend on traversing the
272 // values of the node's children.
273 struct PostorderDFSDep {
274 explicit PostorderDFSDep(int64_t handle, PostorderDFSNodeType type,
275 InferenceContext context, std::string annotation)
276 : handle(handle),
277 type(type),
278 context(std::move(context)),
279 annotation(std::move(annotation)) {}
280 int64_t handle;
281 PostorderDFSNodeType type;
282 InferenceContext context;
283 std::string annotation;
284 };
285
286 // This function represents the logic to visit a node once its dependencies
287 // (operands) are all resolved.
288 using Visit = std::function<StatusOr<Literal>(absl::Span<Literal>)>;
289 // Convenient specializations of Visit function for different operands.
290 using Visit0D = std::function<StatusOr<Literal>()>;
291 using Visit1D = std::function<StatusOr<Literal>(Literal)>;
292 using Visit2D = std::function<StatusOr<Literal>(Literal, Literal)>;
293
294 // A postorder dfs node can be visited once its dependency requests are all
295 // fulfilled.
296 struct [[nodiscard]] PostorderDFSNode {
297 PostorderDFSNode& AddDependency(int64_t handle, PostorderDFSNodeType type,
298 InferenceContext context,
299 std::string annotation = "") {
300 dependencies.emplace_back(handle, type, std::move(context),
301 std::move(annotation));
302 return *this;
303 }
304
305 PostorderDFSNode& AddVisit(const Visit& visit) {
306 this->visit = visit;
307 return *this;
308 }
309
310 PostorderDFSNode& AddVisit(const Visit0D& visit) {
311 this->visit = [visit](absl::Span<Literal> literals) { return visit(); };
312 return *this;
313 }
314
315 PostorderDFSNode& AddVisit(const Visit1D& visit) {
316 this->visit = [visit](absl::Span<Literal> literals) {
317 return visit(std::move(literals[0]));
318 };
319 return *this;
320 }
321
322 PostorderDFSNode& AddVisit(const Visit2D& visit) {
323 this->visit = [visit](absl::Span<Literal> literals) {
324 return visit(std::move(literals[0]), std::move(literals[1]));
325 };
326 return *this;
327 }
328 std::vector<PostorderDFSDep> dependencies;
329 Visit visit;
330 };
331
332 // Converts an integer handle to an HloInstructionProto.
333 using HandleToInstruction =
334 std::function<StatusOr<const HloInstructionProto*>(int64_t)>;
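// Converts a computation handle to an HloComputationProto.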
335 using HandleToComputation = std::function<const HloComputationProto*(int64_t)>;
336
337 struct PostorderDFSVisitor {
338 PostorderDFSVisitor(HloEvaluator& evaluator,
339 HandleToInstruction handle_to_instruction,
340 HandleToComputation handle_to_computation)
341 : evaluator(evaluator),
342 handle_to_instruction(handle_to_instruction),
343 handle_to_computation(handle_to_computation) {}
344
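// The Analyze* methods below build a PostorderDFSNode (dependencies plus a
// visit function) for one instruction under the given analysis type;
// PostOrderDFSVisit drives the traversal and returns the final literal.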
345 StatusOr<PostorderDFSNode> AnalyzeUpperBound(int64_t handle,
346 InferenceContext context);
347 StatusOr<PostorderDFSNode> AnalyzeLowerBound(int64_t handle,
348 InferenceContext context);
349 StatusOr<PostorderDFSNode> AnalyzeIsDynamic(int64_t handle,
350 PostorderDFSNodeType type,
351 InferenceContext context);
352 StatusOr<PostorderDFSNode> AnalyzeConstant(int64_t handle,
353 InferenceContext context);
354 StatusOr<PostorderDFSNode> AnalyzeConstantValueFallback(
355 int64_t handle, PostorderDFSNodeType type, InferenceContext context);
356
357 StatusOr<Literal> PostOrderDFSVisit(int64_t handle,
358 PostorderDFSNodeType type);
359
360 // Returns true if the value represented by `handle` is an integral type or
361 // a floating-point type that was just converted from an integral type.
362 // E.g.,:
363 // int(a) -> true
364 // float(int(a)) -> true
365 // float(a) -> false -- We don't know the concrete value of `a` at
366 // compile time, except for its type.
367 bool IsValueEffectiveInteger(int64_t handle) {
368 // handle_to_instruction's failure status should be checked by parent.
369 const HloInstructionProto* instr =
370 handle_to_instruction(handle).ValueOrDie();
371 if (primitive_util::IsIntegralType(instr->shape().element_type())) {
372 return true;
373 }
374 // Also returns true if this is a convert that converts an integer to float.
375 HloOpcode opcode = StringToHloOpcode(instr->opcode()).ValueOrDie();
376 if (opcode != HloOpcode::kConvert) {
377 return false;
378 }
379 const HloInstructionProto* parent =
380 handle_to_instruction(instr->operand_ids(0)).ValueOrDie();
381 if (primitive_util::IsIntegralType(parent->shape().element_type())) {
382 return true;
383 }
384 return false;
385 }
386
387 // Checks the sizes of the instruction's output and inputs. Returns true if any
388 // of them exceeds kLargeShapeElementLimit and the instruction needs evaluation
389 // (e.g., kGetDimensionSize or kSetDimensionSize doesn't need evaluation).
390 bool IsInstructionOverLimit(const HloInstructionProto* proto,
391 InferenceContext context) {
392 Shape subshape =
393 ShapeUtil::GetSubshape(Shape(proto->shape()), context.shape_index);
394
395 if (subshape.IsArray() &&
396 ShapeUtil::ElementsIn(subshape) > kLargeShapeElementLimit) {
397 return true;
398 }
399 HloOpcode opcode = StringToHloOpcode(proto->opcode()).ValueOrDie();
400 for (int64_t operand_id : proto->operand_ids()) {
401 const HloInstructionProto* operand =
402 handle_to_instruction(operand_id).ValueOrDie();
403 Shape operand_shape = Shape(operand->shape());
404
405 if (operand_shape.IsArray() &&
406 ShapeUtil::ElementsIn(operand_shape) > kLargeShapeElementLimit &&
407 opcode != HloOpcode::kGetDimensionSize &&
408 opcode != HloOpcode::kSetDimensionSize) {
409 return true;
410 }
411 }
412 return false;
413 }
414
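// Cache key identifying a (handle, inference context, analysis type) query.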
415 struct CacheKey {
416 CacheKey(int64_t handle, InferenceContext context,
417 PostorderDFSNodeType type)
418 : handle(handle), context(context), type(type) {}
419 int64_t handle;
420 InferenceContext context;
421 PostorderDFSNodeType type;
422
423 template <typename H>
424 friend H AbslHashValue(H h, const CacheKey& key) {
425 h = H::combine(std::move(h), key.handle);
426 h = H::combine(std::move(h), key.context.shape_index.ToString());
427 h = H::combine(std::move(h),
428 VectorString(key.context.caller_operand_handles));
429 h = H::combine(std::move(h), key.type);
430 return h;
431 }
432
433 friend bool operator==(const CacheKey& lhs, const CacheKey& rhs) {
434 return lhs.handle == rhs.handle &&
435 lhs.context.shape_index == rhs.context.shape_index &&
436 lhs.context.caller_operand_handles ==
437 rhs.context.caller_operand_handles &&
438 lhs.type == rhs.type;
439 }
440 };
441
442 HloEvaluator& evaluator;
443 absl::flat_hash_map<CacheKey, Literal> evaluated;
444 HandleToInstruction handle_to_instruction;
445 HandleToComputation handle_to_computation;
446 // Give up when dealing with more than 1M elements.
447 static constexpr int64_t kLargeShapeElementLimit = 1000 * 1000;
448 };
449
450 // Returns a result representing that the value is fully dynamic and can't be
451 // inferred. In other words, "give up" and return the most conservative value.
452 PostorderDFSNode CreateAllDynamicResult(Shape shape,
453 PostorderDFSNodeType type) {
454 return PostorderDFSNode().AddVisit(
455 [shape, type](absl::Span<Literal>) -> Literal {
456 if (type == PostorderDFSNodeType::kConstantValue ||
457 type == PostorderDFSNodeType::kConstantUpperBound ||
458 type == PostorderDFSNodeType::kConstantLowerBound) {
459 // When inferencing constant values, create garbage data, which will
460 // be masked out by the dynamism counterpart.
461 return CreateGarbageLiteral(shape);
462 } else {
463 // When inferring dynamism, return true, indicating all values are dynamic.
464 return CreatePredLiteral(true, shape);
465 }
466 });
467 }
468
469 } // namespace
470
471 // Analyze a tensor's constant value, upper-bound value or lower-bound value.
472 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstantValueFallback(
473 int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
474 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
475 handle_to_instruction(handle));
476 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
477 Shape subshape =
478 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
479 PostorderDFSNode result;
480 // By default, the dependencies of current node are its operands.
481 for (auto operand_id : root->operand_ids()) {
482 InferenceContext dep_context = context;
483 dep_context.shape_index = {};
484 result.AddDependency(operand_id, type, dep_context);
485 }
486 switch (opcode) {
487 // Non functional ops.
488 case HloOpcode::kRng:
489 case HloOpcode::kAllReduce:
490 case HloOpcode::kReduceScatter:
491 case HloOpcode::kInfeed:
492 case HloOpcode::kOutfeed:
493 case HloOpcode::kRngBitGenerator:
494 case HloOpcode::kCustomCall:
495 case HloOpcode::kWhile:
496 case HloOpcode::kSend:
497 case HloOpcode::kRecv:
498 case HloOpcode::kSendDone:
499 case HloOpcode::kRecvDone:
500 case HloOpcode::kParameter: {
501 if (opcode == HloOpcode::kParameter &&
502 !context.caller_operand_handles.empty()) {
503 int64_t caller_operand = context.caller_operand_handles.back();
504 context.caller_operand_handles.pop_back();
505 return result.AddDependency(caller_operand, type, context)
506 .AddVisit([](Literal literal) { return literal; });
507 }
508 return CreateAllDynamicResult(subshape, type);
509 }
510 // Subtract and Divide use lower-bound as second operand.
511 case HloOpcode::kSubtract:
512 case HloOpcode::kCos:
513 case HloOpcode::kSin:
514 case HloOpcode::kNegate:
515 case HloOpcode::kAbs:
516 case HloOpcode::kDivide:
517 case HloOpcode::kGetDimensionSize: {
518 return InvalidArgument(
519 "AnalyzeConstantValueFallback can't handle opcode: %s",
520 root->opcode());
521 }
522 case HloOpcode::kCall: {
523 auto node = PostorderDFSNode();
524 auto* call_proto = root;
525 if (call_proto->operand_ids_size() != 1) {
526 // Only support single operand forwarding.
527 return CreateAllDynamicResult(subshape, type);
528 }
529 int64_t called_root =
530 handle_to_computation(call_proto->called_computation_ids(0))
531 ->root_id();
532 InferenceContext call_context = context;
533 call_context.caller_operand_handles.push_back(call_proto->operand_ids(0));
534 node.AddDependency(called_root, PostorderDFSNodeType::kConstantValue,
535 call_context, "callee's root instruction");
536 return node.AddVisit([](Literal operand) -> StatusOr<Literal> {
537 // Forward result of callee's root to caller.
538 return std::move(operand);
539 });
540 }
541
542 case HloOpcode::kConditional: {
543 auto node = PostorderDFSNode();
544 auto* conditional_proto = root;
545 InferenceContext predicate_context = context;
546 predicate_context.shape_index = {};
547 // Add dependencies to analyze the predicate of the conditional.
548 node.AddDependency(conditional_proto->operand_ids(0),
549 PostorderDFSNodeType::kConstantValue,
550 predicate_context)
551 .AddDependency(conditional_proto->operand_ids(0),
552 PostorderDFSNodeType::kValueIsDynamic,
553 predicate_context);
554 const int64_t branch_size =
555 conditional_proto->called_computation_ids_size();
556 for (int64_t i = 0; i < branch_size; ++i) {
557 int64_t branch_root =
558 handle_to_computation(conditional_proto->called_computation_ids(i))
559 ->root_id();
560 InferenceContext branch_context = context;
561 branch_context.caller_operand_handles.push_back(
562 conditional_proto->operand_ids(i + 1));
563 node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
564 branch_context);
565 }
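// operands: [0] is the predicate's value, [1] its dynamism, and [2 + i] is
// branch i's constant value.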
566 return node.AddVisit(
567 [](absl::Span<Literal> operands) -> StatusOr<Literal> {
568 int64_t pred_is_dynamic = operands[1].Get<bool>({});
569 if (pred_is_dynamic) {
570 // If predicate is dynamic, return the value of the first branch
571 // -- If all branches return the same value, this is the value
572 // that we want; If not, the value will be masked anyway so the
573 // value inside doesn't matter.
574 return std::move(operands[2]);
575 } else {
576 // If predicate is static, return the value of the given branch.
577 int64_t branch_index = 0;
578 if (operands[0].shape().element_type() == PRED) {
579 if (operands[0].Get<bool>({})) {
580 branch_index = 0;
581 } else {
582 branch_index = 1;
583 }
584 } else {
585 branch_index = operands[0].GetIntegralAsS64({}).value();
586 }
587 const int64_t branch_value_index = 2 + branch_index;
588 return std::move(operands[branch_value_index]);
589 }
590 });
591 }
592 case HloOpcode::kGetTupleElement: {
593 int64_t operand_handle = root->operand_ids(0);
594 PostorderDFSNode result;
595 context.shape_index.push_front(root->tuple_index());
596 return PostorderDFSNode()
597 .AddDependency(operand_handle, type, context)
598 .AddVisit([](Literal operand) { return operand; });
599 }
600 case HloOpcode::kReduce:
601 case HloOpcode::kSort:
602 case HloOpcode::kScatter:
603 case HloOpcode::kReduceWindow: {
604 const HloComputationProto* computation_proto =
605 handle_to_computation(root->called_computation_ids(0));
606 return result.AddVisit(
607 [root, computation_proto, context,
608 this](absl::Span<Literal> operands) -> StatusOr<Literal> {
609 TF_ASSIGN_OR_RETURN(
610 auto computation,
611 HloComputation::CreateFromProto(*computation_proto, {}));
612 return HloProtoEvaluator(evaluator, *root)
613 .WithOperands(operands)
614 .WithComputation(std::move(computation))
615 .WithSubshape(context.shape_index)
616 .Evaluate();
617 });
618 }
619 default: {
620 if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
621 // There could be many operands of a tuple, but only one that we are
622 // interested in, represented by `tuple_operand_index`.
623 int64_t tuple_operand_index = context.shape_index.front();
624 InferenceContext tuple_operand_context = context;
625 tuple_operand_context.shape_index.pop_front();
626 return PostorderDFSNode()
627 .AddDependency(root->operand_ids(tuple_operand_index), type,
628 tuple_operand_context)
629 .AddVisit([](Literal operand) { return operand; });
630 }
631 return result.AddVisit([root, this](absl::Span<Literal> operands) {
632 return HloProtoEvaluator(evaluator, *root)
633 .WithOperands(operands)
634 .Evaluate();
635 });
636 }
637 }
638 }
639
640 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeUpperBound(
641 int64_t handle, InferenceContext context) {
642 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
643 handle_to_instruction(handle));
644 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
645 Shape subshape =
646 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
647
648 if (IsInstructionOverLimit(root, context)) {
649 return CreateAllDynamicResult(subshape,
650 PostorderDFSNodeType::kConstantUpperBound);
651 }
652 switch (opcode) {
653 case HloOpcode::kGetDimensionSize: {
654 int64_t dimension = root->dimensions(0);
655 int64_t operand_handle = root->operand_ids(0);
656 const HloInstructionProto* operand_proto =
657 handle_to_instruction(operand_handle).ValueOrDie();
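// The dimension's static (bound) size is a valid upper bound even when the
// dimension is dynamic.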
658 return PostorderDFSNode().AddVisit(
659 [operand_proto, dimension]() -> StatusOr<Literal> {
660 return LiteralUtil::CreateR0<int32_t>(
661 operand_proto->shape().dimensions(dimension));
662 });
663 }
664 case HloOpcode::kAbs: {
665 // upper-bound(abs(operand)) = max(abs(lower-bound(operand)),
666 // abs(upper-bound(operand)))
667 return PostorderDFSNode()
668 .AddDependency(root->operand_ids(0),
669 PostorderDFSNodeType::kConstantLowerBound, context)
670 .AddDependency(root->operand_ids(0),
671 PostorderDFSNodeType::kConstantUpperBound, context)
672 .AddVisit([this](Literal lower_bound,
673 Literal upper_bound) -> StatusOr<Literal> {
674 TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
675 evaluator.EvaluateElementwiseUnaryOp(
676 HloOpcode::kAbs, lower_bound));
677 TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
678 evaluator.EvaluateElementwiseUnaryOp(
679 HloOpcode::kAbs, upper_bound));
680 return evaluator.EvaluateElementwiseBinaryOp(
681 HloOpcode::kMaximum, lower_bound_abs, upper_bound_abs);
682 });
683 }
684 case HloOpcode::kSort: {
685 auto dfs = PostorderDFSNode();
686 InferenceContext dep_context = context;
687 dep_context.shape_index = {};
688 if (!context.shape_index.empty()) {
689 // Lazy evaluation: Only need to evaluate a subelement in a
690 // variadic-sort tensor.
691 dfs.AddDependency(root->operand_ids(context.shape_index[0]),
692 PostorderDFSNodeType::kConstantUpperBound,
693 dep_context);
694 } else {
695 for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
696 dfs.AddDependency(root->operand_ids(i),
697 PostorderDFSNodeType::kConstantUpperBound,
698 dep_context);
699 }
700 }
701
702 return dfs.AddVisit(
703 [root, context](absl::Span<Literal> operands) -> StatusOr<Literal> {
704 std::vector<Literal> results;
705 results.reserve(operands.size());
706 // Conservatively set each element of the tensor to the max value.
707 for (int64_t i = 0; i < operands.size(); ++i) {
708 auto max = LiteralUtil::MaxElement(operands[i]);
709 results.emplace_back(
710 max.Broadcast(operands[i].shape(), {}).ValueOrDie());
711 }
712 if (ShapeUtil::GetSubshape(Shape(root->shape()),
713 context.shape_index)
714 .IsTuple()) {
715 return LiteralUtil::MakeTupleOwned(std::move(results));
716 } else {
717 return std::move(results[0]);
718 }
719 });
720 }
721 case HloOpcode::kNegate: {
722 // upper-bound(negate(operand)) = negate(lower-bound(operand))
723 return PostorderDFSNode()
724 .AddDependency(root->operand_ids(0),
725 PostorderDFSNodeType::kConstantLowerBound, context)
726 .AddVisit([this](Literal lower_bound) -> StatusOr<Literal> {
727 return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
728 lower_bound);
729 });
730 }
731 case HloOpcode::kSubtract:
732 case HloOpcode::kDivide: {
733 // Lower-bound is used for second operand of subtract and divide.
734 return PostorderDFSNode()
735 .AddDependency(root->operand_ids(0),
736 PostorderDFSNodeType::kConstantUpperBound, context)
737 .AddDependency(root->operand_ids(1),
738 PostorderDFSNodeType::kConstantLowerBound, context)
739 .AddVisit([root, opcode, this](
740 Literal upper_bound,
741 Literal lower_bound) -> StatusOr<Literal> {
742 if (opcode == HloOpcode::kDivide &&
743 this->IsValueEffectiveInteger(root->operand_ids(1))) {
744 // Because in many cases the lower bound of a value is
745 // integer 0, instead of throwing a divide-by-zero error
746 // at compile time, we defer the check to runtime. In
747 // those cases we use the upper bound of the
748 // first operand as a placeholder.
749 auto zero = LiteralUtil::Zero(lower_bound.shape().element_type());
750 zero = zero.Broadcast(lower_bound.shape(), {}).ValueOrDie();
751 TF_ASSIGN_OR_RETURN(
752 auto lower_bound_is_zero,
753 evaluator.EvaluateElementwiseCompareOp(
754 ComparisonDirection::kEq, lower_bound, zero));
755
756 auto one = LiteralUtil::One(lower_bound.shape().element_type());
757 one = one.Broadcast(lower_bound.shape(), {}).ValueOrDie();
758 TF_ASSIGN_OR_RETURN(
759 lower_bound, evaluator.EvaluateElementwiseTernaryOp(
760 HloOpcode::kSelect, lower_bound_is_zero, one,
761 lower_bound));
762 }
763 std::vector<Literal> new_operands;
764 new_operands.emplace_back(std::move(upper_bound));
765 new_operands.emplace_back(std::move(lower_bound));
766 return HloProtoEvaluator(evaluator, *root)
767 .WithOperands(absl::MakeSpan(new_operands))
768 .Evaluate();
769 });
770 }
771 case HloOpcode::kCustomCall: {
772 if (root->custom_call_target() == "SetBound") {
773 return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
774 if (root->literal().shape().element_type() == TUPLE) {
775 // First literal of SetBound contains bounds, second literal
776 // contains dynamism indicators.
777 return Literal::CreateFromProto(root->literal().tuple_literals(0));
778 } else {
779 return Literal::CreateFromProto(root->literal());
780 }
781 });
782 } else if (root->custom_call_target() == "Sharding") {
783 return PostorderDFSNode()
784 .AddDependency(root->operand_ids(0),
785 PostorderDFSNodeType::kConstantUpperBound, context)
786 .AddVisit([](Literal operand) { return operand; });
787 }
788 return InvalidArgument(
789 "Upper-bound inferencing on custom call %s is not supported",
790 root->DebugString());
791 }
792 case HloOpcode::kGather: {
793 return PostorderDFSNode()
794 .AddDependency(root->operand_ids(0),
795 PostorderDFSNodeType::kConstantUpperBound, context)
796 .AddDependency(root->operand_ids(1),
797 PostorderDFSNodeType::kConstantValue, context)
798 .AddVisit([root, this](absl::Span<Literal> operands) {
799 return HloProtoEvaluator(evaluator, *root)
800 .WithOperands(operands)
801 .Evaluate();
802 });
803 }
804 default:
805 return AnalyzeConstantValueFallback(
806 handle, PostorderDFSNodeType::kConstantUpperBound, context);
807 }
808 }
809
810 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeLowerBound(
811 int64_t handle, InferenceContext context) {
812 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
813 handle_to_instruction(handle));
814 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
815 Shape subshape =
816 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
817 if (IsInstructionOverLimit(root, context)) {
818 return CreateAllDynamicResult(subshape,
819 PostorderDFSNodeType::kConstantLowerBound);
820 }
821 switch (opcode) {
822 case HloOpcode::kGetDimensionSize: {
823 int64_t dimension = root->dimensions(0);
824 int64_t operand_handle = root->operand_ids(0);
825 TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
826 handle_to_instruction(operand_handle));
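// A dynamic dimension can be as small as 0, so its lower bound is 0;
// otherwise the static size is exact.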
827 return PostorderDFSNode().AddVisit(
828 [dimension, operand_proto]() -> StatusOr<Literal> {
829 if (operand_proto->shape().is_dynamic_dimension(dimension)) {
830 return LiteralUtil::CreateR0<int32_t>(0);
831 } else {
832 return LiteralUtil::CreateR0<int32_t>(
833 operand_proto->shape().dimensions(dimension));
834 }
835 });
836 }
837 case HloOpcode::kAbs: {
838 // lower-bound(abs(operand)) = min(abs(lower-bound(operand)),
839 // abs(upper-bound(operand)))
840 return PostorderDFSNode()
841 .AddDependency(root->operand_ids(0),
842 PostorderDFSNodeType::kConstantLowerBound, context)
843 .AddDependency(root->operand_ids(0),
844 PostorderDFSNodeType::kConstantUpperBound, context)
845 .AddVisit([this](Literal lower_bound,
846 Literal upper_bound) -> StatusOr<Literal> {
847 TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
848 evaluator.EvaluateElementwiseUnaryOp(
849 HloOpcode::kAbs, lower_bound));
850 TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
851 evaluator.EvaluateElementwiseUnaryOp(
852 HloOpcode::kAbs, upper_bound));
853 return evaluator.EvaluateElementwiseBinaryOp(
854 HloOpcode::kMinimum, lower_bound_abs, upper_bound_abs);
855 });
856 }
857 case HloOpcode::kNegate: {
858 // lower-bound(negate(operand)) = negate(upper-bound(operand))
859 return PostorderDFSNode()
860 .AddDependency(root->operand_ids(0),
861 PostorderDFSNodeType::kConstantUpperBound, context)
862 .AddVisit([this](Literal upper_bound) -> StatusOr<Literal> {
863 return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
864 upper_bound);
865 });
866 }
867 case HloOpcode::kSubtract:
868 case HloOpcode::kDivide: {
869 // Upper bound is used for second operand of subtract and divide.
870 return PostorderDFSNode()
871 .AddDependency(root->operand_ids(0),
872 PostorderDFSNodeType::kConstantLowerBound, context)
873 .AddDependency(root->operand_ids(1),
874 PostorderDFSNodeType::kConstantUpperBound, context)
875 .AddVisit(
876 [root, this](absl::Span<Literal> operands) -> StatusOr<Literal> {
877 return HloProtoEvaluator(evaluator, *root)
878 .WithOperands(operands)
879 .Evaluate();
880 });
881 }
882 case HloOpcode::kGather: {
883 return PostorderDFSNode()
884 .AddDependency(root->operand_ids(0),
885 PostorderDFSNodeType::kConstantLowerBound, context)
886 .AddDependency(root->operand_ids(1),
887 PostorderDFSNodeType::kConstantValue, context)
888 .AddVisit([root, this](absl::Span<Literal> operands) {
889 return HloProtoEvaluator(evaluator, *root)
890 .WithOperands(operands)
891 .Evaluate();
892 });
893 }
894 default:
895 return AnalyzeConstantValueFallback(
896 handle, PostorderDFSNodeType::kConstantLowerBound, context);
897 }
898 }
899
900 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstant(
901 int64_t handle, InferenceContext context) {
902 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
903 handle_to_instruction(handle));
904 HloOpcode opcode = StringToHloOpcode(root->opcode()).ValueOrDie();
905 Shape subshape =
906 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
907 if (IsInstructionOverLimit(root, context)) {
908 return CreateAllDynamicResult(subshape,
909 PostorderDFSNodeType::kConstantValue);
910 }
911 switch (opcode) {
912 case HloOpcode::kGetDimensionSize: {
913 int64_t dimension = root->dimensions(0);
914 int64_t operand_handle = root->operand_ids(0);
915 TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
916 handle_to_instruction(operand_handle));
917 return PostorderDFSNode().AddVisit(
918 [operand_proto, dimension, root]() -> StatusOr<Literal> {
919 if (operand_proto->shape().is_dynamic_dimension(dimension)) {
920 // The value is dynamic; we return garbage data here and mask it
921 // out later.
922 return CreateGarbageLiteral(Shape(root->shape()));
923 } else {
924 return LiteralUtil::CreateR0<int32_t>(
925 operand_proto->shape().dimensions(dimension));
926 }
927 });
928 }
929 case HloOpcode::kSubtract:
930 case HloOpcode::kCos:
931 case HloOpcode::kSin:
932 case HloOpcode::kNegate:
933 case HloOpcode::kAbs:
934 case HloOpcode::kDivide: {
935 PostorderDFSNode result;
936 for (auto operand_id : root->operand_ids()) {
937 result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
938 context);
939 }
940 return result.AddVisit(
941 [root, this](absl::Span<Literal> operands) -> StatusOr<Literal> {
942 return HloProtoEvaluator(evaluator, *root)
943 .WithOperands(operands)
944 .Evaluate();
945 });
946 }
947 case HloOpcode::kCustomCall: {
948 if (root->custom_call_target() == "SetBound") {
949 // `SetBound` doesn't change the static value of a tensor, so forward
950 // the operand when analyzing static value.
951 return PostorderDFSNode()
952 .AddDependency(root->operand_ids(0),
953 PostorderDFSNodeType::kConstantValue, context)
954 .AddVisit(
955 [](Literal operand) -> StatusOr<Literal> { return operand; });
956 } else if (root->custom_call_target() == "Sharding") {
957 return PostorderDFSNode()
958 .AddDependency(root->operand_ids(0),
959 PostorderDFSNodeType::kConstantValue, context)
960 .AddVisit([](Literal operand) { return operand; });
961 } else {
962 return PostorderDFSNode().AddVisit(
963 [root, context](absl::Span<Literal>) {
964 // The value is dynamic. We return a garbage literal here, which
965 // will be masked out later.
966 return CreateGarbageLiteral(ShapeUtil::GetSubshape(
967 Shape(root->shape()), context.shape_index));
968 });
969 }
970 }
971 case HloOpcode::kSort: {
972 PostorderDFSNode result;
973 InferenceContext dep_context = context;
974 dep_context.shape_index = {};
975 for (auto operand_id : root->operand_ids()) {
976 result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
977 dep_context);
978 }
979 const HloComputationProto* computation_proto =
980 handle_to_computation(root->called_computation_ids(0));
981 return result.AddVisit(
982 [root, context, computation_proto,
983 this](absl::Span<Literal> operands) -> StatusOr<Literal> {
984 TF_ASSIGN_OR_RETURN(
985 auto computation,
986 HloComputation::CreateFromProto(*computation_proto, {}));
987 return HloProtoEvaluator(evaluator, *root)
988 .WithOperands(operands)
989 .WithComputation(std::move(computation))
990 .WithSubshape(context.shape_index)
991 .Evaluate();
992 });
993 }
994 default:
995 return AnalyzeConstantValueFallback(
996 handle, PostorderDFSNodeType::kConstantValue, context);
997 }
998 }
999
1000 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeIsDynamic(
1001 int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
1002 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
1003 handle_to_instruction(handle));
1004 // Invariant check.
1005 TF_RET_CHECK(root);
1006 VLOG(1) << "Analyzing IsDynamic on " << root->DebugString();
1007 Shape subshape =
1008 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
1009 if (IsInstructionOverLimit(root, context)) {
1010 return CreateAllDynamicResult(subshape, type);
1011 }
1012 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
1013 PostorderDFSNode result;
1014 for (auto operand_id : root->operand_ids()) {
1015 InferenceContext dep_context = context;
1016 dep_context.shape_index = {};
1017 result.AddDependency(operand_id, type, dep_context);
1018 }
1019 switch (opcode) {
1020 case HloOpcode::kGetDimensionSize: {
1021 int64_t dimension = root->dimensions(0);
1022 int64_t operand_handle = root->operand_ids(0);
1023 TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
1024 handle_to_instruction(operand_handle));
1025 return PostorderDFSNode().AddVisit(
1026 [operand_proto, dimension, type]() -> StatusOr<Literal> {
1027 if (type == PostorderDFSNodeType::kBoundIsDynamic) {
1028 // The bound of dynamic dimension is not dynamic.
1029 return LiteralUtil::CreateR0<bool>(false);
1030 }
1031 // The value of dynamic dimension is dynamic.
1032 return LiteralUtil::CreateR0<bool>(
1033 operand_proto->shape().is_dynamic_dimension(dimension));
1034 });
1035 }
1036 case HloOpcode::kSort: {
1037 auto dfs = PostorderDFSNode();
1038 InferenceContext dep_context = context;
1039 dep_context.shape_index = {};
1040
1041 for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
1042 dfs.AddDependency(root->operand_ids(i), type, dep_context);
1043 }
1044
1045 return dfs.AddVisit([root, context, type](absl::Span<Literal> operands)
1046 -> StatusOr<Literal> {
1047 bool all_operands_values_static = true;
1048 for (int64_t i = 0; i < operands.size(); ++i) {
1049 all_operands_values_static &= operands[i].IsAll(0);
1050 }
1051 if (type == PostorderDFSNodeType::kValueIsDynamic) {
1052 // If even a single operand of a sort is dynamic, we
1053 // conservatively say all results are dynamic.
1054 return CreatePredLiteral(!all_operands_values_static,
1055 ShapeUtil::GetSubshape(Shape(root->shape()),
1056 context.shape_index));
1057 }
1058 CHECK(type == PostorderDFSNodeType::kBoundIsDynamic);
1059 // The condition for bounds is more relaxed than for values. If we know the
1060 // bounds of each element [B0, B1... Bn], all results have the same
1061 // bound
1062 // [max(B0, B1...), max(B0, B1...), ...]
1063 if (!context.shape_index.empty()) {
1064 int64_t index = context.shape_index[0];
1065 bool all_values_static = operands[index].IsAll(0);
1066 return CreatePredLiteral(!all_values_static, operands[index].shape());
1067 }
1068
1069 std::vector<Literal> results;
1070 results.reserve(operands.size());
1071 for (int64_t i = 0; i < operands.size(); ++i) {
1072 bool all_values_static = operands[i].IsAll(0);
1073 results.emplace_back(
1074 CreatePredLiteral(!all_values_static, operands[i].shape()));
1075 }
1076 if (!ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index)
1077 .IsTuple()) {
1078 return std::move(results[0]);
1079 }
1080 return LiteralUtil::MakeTupleOwned(std::move(results));
1081 });
1082 }
1083 case HloOpcode::kSetDimensionSize:
1084 return result.AddVisit([root, type](absl::Span<Literal> operands) {
1085 bool any_dynamic_operand = absl::c_any_of(
1086 operands, [](Literal& operand) { return !operand.IsAll(0); });
1087 // If values in a tensor `t` with bound are [e0, e1, e2...], we can say
1088 // the max value of each position is [max(t), max(t), max(t), ...]. The
1089 // effective size of this tensor doesn't change the max value.
1090 return CreatePredLiteral(
1091 type == PostorderDFSNodeType::kValueIsDynamic &&
1092 any_dynamic_operand,
1093 ShapeUtil::MakeStaticShape(Shape(root->shape())));
1094 });
1095 case HloOpcode::kDynamicSlice: {
1096 return result.AddVisit([root](absl::Span<Literal> operands) {
1097 // If any of the operands is dynamic, we say the output is dynamic.
1098 bool any_dynamic_operand = absl::c_any_of(
1099 operands, [](Literal& operand) { return !operand.IsAll(0); });
1100 return CreatePredLiteral(any_dynamic_operand, Shape(root->shape()));
1101 });
1102 }
1103 case HloOpcode::kAbs:
1104 case HloOpcode::kRoundNearestAfz:
1105 case HloOpcode::kRoundNearestEven:
1106 case HloOpcode::kBitcast:
1107 case HloOpcode::kCeil:
1108 case HloOpcode::kCollectivePermuteDone:
1109 case HloOpcode::kCos:
1110 case HloOpcode::kClz:
1111 case HloOpcode::kExp:
1112 case HloOpcode::kExpm1:
1113 case HloOpcode::kFloor:
1114 case HloOpcode::kImag:
1115 case HloOpcode::kIsFinite:
1116 case HloOpcode::kLog:
1117 case HloOpcode::kLog1p:
1118 case HloOpcode::kNot:
1119 case HloOpcode::kNegate:
1120 case HloOpcode::kPopulationCount:
1121 case HloOpcode::kReal:
1122 case HloOpcode::kRsqrt:
1123 case HloOpcode::kLogistic:
1124 case HloOpcode::kSign:
1125 case HloOpcode::kSin:
1126 case HloOpcode::kConvert:
1127 case HloOpcode::kSqrt:
1128 case HloOpcode::kCbrt:
1129 case HloOpcode::kTanh: {
1130 // Forward the operand, as these ops don't change whether a value is dynamic
1130 // or static.
1131 return result.AddVisit([](Literal operand) { return operand; });
1132 }
1133 case HloOpcode::kAdd:
1134 case HloOpcode::kAtan2:
1135 case HloOpcode::kDivide:
1136 case HloOpcode::kComplex:
1137 case HloOpcode::kMaximum:
1138 case HloOpcode::kMinimum:
1139 case HloOpcode::kMultiply:
1140 case HloOpcode::kPower:
1141 case HloOpcode::kRemainder:
1142 case HloOpcode::kSubtract:
1143 case HloOpcode::kCompare:
1144 case HloOpcode::kAnd:
1145 case HloOpcode::kOr:
1146 case HloOpcode::kXor:
1147 case HloOpcode::kShiftLeft:
1148 case HloOpcode::kShiftRightArithmetic:
1149 case HloOpcode::kShiftRightLogical: {
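// An output element is dynamic wherever any operand element is dynamic, so
// evaluate the same instruction with its opcode replaced by OR over the
// operands' dynamism masks (as PRED).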
1150 return result.AddVisit([root, this](absl::Span<Literal> operands) {
1151 return HloProtoEvaluator(evaluator, *root)
1152 .WithOperands(operands)
1153 .WithPrimitiveType(PRED)
1154 .WithOpCode(HloOpcode::kOr)
1155 .Evaluate();
1156 });
1157 }
1158 case HloOpcode::kTuple:
1159 case HloOpcode::kTranspose:
1160 case HloOpcode::kSlice:
1161 case HloOpcode::kBroadcast:
1162 case HloOpcode::kReverse:
1163 case HloOpcode::kConcatenate:
1164 case HloOpcode::kReshape:
1165 case HloOpcode::kPad: {
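// These ops only move or replicate elements, so running the same op on the
// operands' dynamism masks (as PRED) produces the output's dynamism mask.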
1166 if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
1167 // There could be many operands of a tuple, but only one that we are
1168 // interested in, represented by `tuple_operand_index`.
1169 int64_t tuple_operand_index = context.shape_index.front();
1170 InferenceContext tuple_operand_context = context;
1171 tuple_operand_context.shape_index.pop_front();
1172 return PostorderDFSNode()
1173 .AddDependency(root->operand_ids(tuple_operand_index), type,
1174 tuple_operand_context)
1175 .AddVisit([](Literal operand) { return operand; });
1176 }
1177 return result.AddVisit([root, this](absl::Span<Literal> operands) {
1178 return HloProtoEvaluator(evaluator, *root)
1179 .WithOperands(operands)
1180 .WithPrimitiveType(PRED)
1181 .Evaluate();
1182 });
1183 }
1184 case HloOpcode::kCall: {
1185 auto node = PostorderDFSNode();
1186 auto* call_proto = root;
1187
1188 if (call_proto->operand_ids_size() != 1) {
1189 // Only support single operand forwarding.
1190 return CreateAllDynamicResult(subshape, type);
1191 }
1192 int64_t call_root =
1193 handle_to_computation(call_proto->called_computation_ids(0))
1194 ->root_id();
1195 InferenceContext branch_context = context;
1196 branch_context.caller_operand_handles.push_back(
1197 call_proto->operand_ids(0));
1198 node.AddDependency(call_root, PostorderDFSNodeType::kValueIsDynamic,
1199 branch_context, "callee's root instruction");
1200 return node.AddVisit([context](Literal operand) -> StatusOr<Literal> {
1201 // Forward result of callee's root to caller.
1202 return operand;
1203 });
1204 }
1205 case HloOpcode::kConditional: {
1206 auto node = PostorderDFSNode();
1207 auto* conditional_proto = root;
1208 InferenceContext predicate_context = context;
1209 predicate_context.shape_index = {};
1210 // Add dependencies to analyze the predicate of the conditional.
1211 node.AddDependency(conditional_proto->operand_ids(0),
1212 PostorderDFSNodeType::kConstantValue,
1213 predicate_context)
1214 .AddDependency(conditional_proto->operand_ids(0),
1215 PostorderDFSNodeType::kValueIsDynamic,
1216 predicate_context);
1217 const int64_t branch_size =
1218 conditional_proto->called_computation_ids_size();
1219 for (int64_t i = 0; i < branch_size; ++i) {
1220 int64_t branch_root =
1221 handle_to_computation(conditional_proto->called_computation_ids(i))
1222 ->root_id();
1223 InferenceContext branch_context = context;
1224 branch_context.caller_operand_handles.push_back(
1225 conditional_proto->operand_ids(i + 1));
1226 node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
1227 branch_context,
1228 absl::StrFormat("branch %lld's value", i))
1229 .AddDependency(branch_root, PostorderDFSNodeType::kValueIsDynamic,
1230 branch_context,
1231 absl::StrFormat("branch %lld's dynamism", i));
1232 }
1233 // Predicate uses 2 dependencies:
1234 // 0: Predicate value.
1235 // 1: Predicate is dynamic.
1236 // Each branch i has 2 dependencies:
1237 // 2 + 2*i: Branch result value.
1238 // 2 + 2*i + 1: Branch value is dynamic.
1239 return node.AddVisit([root, branch_size,
1240 context](absl::Span<Literal> operands)
1241 -> StatusOr<Literal> {
1242 int64_t pred_is_dynamic = operands[1].Get<bool>({});
1243 auto result = CreatePredLiteral(
1244 true,
1245 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1246 if (pred_is_dynamic) {
1247 VLOG(1) << "predict is dynamic value" << result.ToString();
1248 // If predicate is dynamic, the result is only static if all
1249 // branches are static and return the same value.
1250 result.MutableEachCell<bool>(
1251 [&](absl::Span<const int64_t> indices, bool value) {
1252 std::string branch_value = operands[2].GetAsString(indices, {});
1253 for (int64_t i = 0; i < branch_size; ++i) {
1254 const int64_t branch_value_index = 2 + 2 * i;
1255 const int64_t branch_dynamism_index = 2 + 2 * i + 1;
1256 auto branch_is_dynamic =
1257 operands[branch_dynamism_index].Get<bool>(indices);
1258 if (branch_is_dynamic) {
1259 return true;
1260 }
1261
1262 if (branch_value !=
1263 operands[branch_value_index].GetAsString(indices, {})) {
1264 return true;
1265 }
1266 }
1267 // Value of the branch is static.
1268 return false;
1269 });
1270 return result;
1271 } else {
1272 VLOG(1) << "predict is constant value";
1273 // If predicate is static, return true if given branch result
1274 // value is dynamic.
1275 int64_t branch_index = 0;
1276 if (operands[0].shape().element_type() == PRED) {
1277 if (operands[0].Get<bool>({})) {
1278 branch_index = 0;
1279 } else {
1280 branch_index = 1;
1281 }
1282 } else {
1283 branch_index = operands[0].GetIntegralAsS64({}).value();
1284 }
1285 const int64_t branch_dynamism_index = 2 + 2 * branch_index + 1;
1286 return std::move(operands[branch_dynamism_index]);
1287 }
1288 });
1289 }
1290 case HloOpcode::kGetTupleElement: {
1291 int64_t operand_handle = root->operand_ids(0);
1292 PostorderDFSNode result;
1293 context.shape_index.push_front(root->tuple_index());
1294 return PostorderDFSNode()
1295 .AddDependency(operand_handle, type, context)
1296 .AddVisit([](Literal operand) { return operand; });
1297 }
1298
1299 case HloOpcode::kReduce: {
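// A reduce output element is dynamic if any element contributing to it is
// dynamic, so rerun the reduce with a computation that ORs the dynamism bits.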
1300 return result.AddVisit(
1301 [root, context, this](absl::Span<Literal> operands) {
1302 Shape root_shape = Shape(root->shape());
1303 Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
1304 std::unique_ptr<HloComputation> reduce_or;
1305 if (root_shape.IsTuple()) {
1306 // Variadic reduce.
1307 HloComputation::Builder b("reduce_or");
1308 // Assuming all operands interact with each other. This could be
1309 // overly conservative. If needed, a dataflow analysis could be
1310 // performed in the future.
1311 //
1312 // The value starts with `false` (static) and will be `or`ed with
1313 // all operands' dynamism.
1314 auto accum = b.AddInstruction(HloInstruction::CreateConstant(
1315 LiteralUtil::CreateR0<bool>(false)));
1316
1317 for (int i = 0; i < root_shape.tuple_shapes_size(); ++i) {
1318 auto lhs = b.AddInstruction(
1319 HloInstruction::CreateParameter(i, scalar_shape, "lhs"));
1320 auto rhs = b.AddInstruction(HloInstruction::CreateParameter(
1321 i + root_shape.tuple_shapes_size(), scalar_shape, "rhs"));
1322 accum = b.AddInstruction(HloInstruction::CreateBinary(
1323 scalar_shape, HloOpcode::kOr, accum, lhs));
1324 accum = b.AddInstruction(HloInstruction::CreateBinary(
1325 scalar_shape, HloOpcode::kOr, accum, rhs));
1326 }
1327 // Duplicate the accumulated result into every position of the output tuple.
1328 std::vector<HloInstruction*> results(
1329 root_shape.tuple_shapes_size(), accum);
1330 b.AddInstruction(HloInstruction::CreateTuple(results));
1331 reduce_or = b.Build();
1332 } else {
1333 HloComputation::Builder b("reduce_or");
1334 auto lhs = b.AddInstruction(
1335 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1336 auto rhs = b.AddInstruction(
1337 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1338 b.AddInstruction(HloInstruction::CreateBinary(
1339 scalar_shape, HloOpcode::kOr, lhs, rhs));
1340 reduce_or = b.Build();
1341 }
1342
1343 return HloProtoEvaluator(evaluator, *root)
1344 .WithOperands(operands)
1345 .WithPrimitiveType(PRED)
1346 .WithComputation(std::move(reduce_or))
1347 // Reduce could produce tuple shape, only fetch what we need.
1348 .WithSubshape(context.shape_index)
1349 .Evaluate();
1350 });
1351 }
1352 case HloOpcode::kConstant:
1353 case HloOpcode::kIota: {
1354 return result.AddVisit(
1355 [root]() { return CreatePredLiteral(false, Shape(root->shape())); });
1356 }
1357 case HloOpcode::kParameter: {
1358 if (opcode == HloOpcode::kParameter &&
1359 !context.caller_operand_handles.empty()) {
1360 int64_t caller_operand = context.caller_operand_handles.back();
1361 context.caller_operand_handles.pop_back();
1362 return result.AddDependency(caller_operand, type, context)
1363 .AddVisit([](Literal literal) { return literal; });
1364 }
1365 return result.AddVisit([root, context]() {
1366 return CreatePredLiteral(
1367 true,
1368 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1369 });
1370 }
1371 case HloOpcode::kSelect: {
1372 return PostorderDFSNode()
1373 .AddDependency(root->operand_ids(0),
1374 PostorderDFSNodeType::kConstantValue, context)
1375 .AddDependency(root->operand_ids(0),
1376 PostorderDFSNodeType::kValueIsDynamic, context)
1377 // lhs dependency.
1378 .AddDependency(root->operand_ids(1), type, context)
1379 // rhs dependency.
1380 .AddDependency(root->operand_ids(2), type, context)
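// operands: [0] selector value, [1] selector dynamism, [2] lhs dynamism,
// [3] rhs dynamism.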
1381 .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
1382 OptionalLiteral optional_selector_literal(std::move(operands[0]),
1383 std::move(operands[1]));
1384 Literal lhs = std::move(operands[2]);
1385 Literal rhs = std::move(operands[3]);
1386 auto result = CreatePredLiteral(true, Shape(root->shape()));
1387 result.MutableEachCell<bool>(
1388 [&](absl::Span<const int64_t> indices, bool value) {
1389 std::optional<bool> optional_selector =
1390 optional_selector_literal.Get<bool>(indices);
1391
1392 bool lhs_value = lhs.Get<bool>(indices);
1393 bool rhs_value = rhs.Get<bool>(indices);
1394 if (optional_selector.has_value()) {
1395 // Manually evaluate the selection without using Evaluator.
1396 if (*optional_selector) {
1397 return lhs_value;
1398 } else {
1399 return rhs_value;
1400 }
1401 } else {
1402 // Conservatively assume value is dynamic if selector is
1403 // dynamic.
1404 return true;
1405 }
1406 });
1407 return result;
1408 });
1409 }
1410 case HloOpcode::kGather: {
1411 return PostorderDFSNode()
1412 .AddDependency(root->operand_ids(0), type, context)
1413 .AddDependency(root->operand_ids(1),
1414 PostorderDFSNodeType::kConstantValue, context)
1415 .AddDependency(root->operand_ids(1),
1416 PostorderDFSNodeType::kValueIsDynamic, context)
1417 .AddVisit(
1418 [root, this](absl::Span<Literal> operands) -> StatusOr<Literal> {
1419 OptionalLiteral optional_selector_literal(
1420 std::move(operands[1]), std::move(operands[2]));
1421
1422 if (!optional_selector_literal.AllValid()) {
1423 // Conservatively assume results are dynamic.
1424 return CreatePredLiteral(true, Shape(root->shape()));
1425 }
1426 std::vector<Literal> new_operands;
1427 new_operands.emplace_back(std::move(operands[0]));
1428 new_operands.emplace_back(
1429 optional_selector_literal.GetValue()->Clone());
1430
1431 return HloProtoEvaluator(evaluator, *root)
1432 .WithOperands(absl::MakeSpan(new_operands))
1433 .WithPrimitiveType(PRED)
1434 .Evaluate();
1435 });
1436 }
    case HloOpcode::kCustomCall: {
      if (root->custom_call_target() == "SetBound") {
        return PostorderDFSNode().AddVisit([type, root]() -> StatusOr<Literal> {
          if (type == PostorderDFSNodeType::kBoundIsDynamic) {
            return CreatePredLiteral(false, Shape(root->shape()));
          } else {
            if (root->literal().shape().element_type() == TUPLE) {
              // First literal of SetBound contains bounds, second literal
              // contains dynamism indicators.
              return Literal::CreateFromProto(
                  root->literal().tuple_literals(1));
            } else if (type == PostorderDFSNodeType::kValueIsDynamic) {
              return CreatePredLiteral(true, Shape(root->shape()));
            } else {
              return Literal::CreateFromProto(root->literal());
            }
          }
        });
      } else if (root->custom_call_target() == "Sharding") {
        return result.AddVisit([](Literal operand) { return operand; });
      } else {
        return InvalidArgument(
            "Dynamic inferencing on custom call %s is not supported",
            root->DebugString());
      }

      break;
    }

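    // Send/recv and while ops are not analyzed; conservatively report every
    // element of the requested subshape as dynamic.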
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kWhile: {
      return PostorderDFSNode().AddVisit([root,
                                          context]() -> StatusOr<Literal> {
        return CreatePredLiteral(
            true,
            ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
      });
      break;
    }
    default:
      return PostorderDFSNode().AddVisit([root,
                                          context]() -> StatusOr<Literal> {
        return CreatePredLiteral(
            true,
            ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
      });
  }
}

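// Iteratively runs a post-order DFS from `handle` using an explicit work
// stack instead of recursion. Each node is pushed in the kUnvisited state,
// which enqueues its dependencies; when it is revisited in the kVisiting
// state, the dependencies' literals are pulled from the `evaluated` cache and
// the node's visit function produces its own literal, which is then cached
// under its (handle, context, type) key.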
StatusOr<Literal> PostorderDFSVisitor::PostOrderDFSVisit(
    int64_t handle, PostorderDFSNodeType type) {
  enum VisitState {
    kUnvisited = 0,
    kVisiting,
    kVisited,
  };

  int64_t unique_id = 0;
  struct WorkItem {
    explicit WorkItem(int64_t handle, InferenceContext context,
                      PostorderDFSNodeType type, VisitState state, int64_t id)
        : handle(handle),
          context(std::move(context)),
          type(type),
          state(state),
          id(id) {}
    int64_t handle;  // Handle of the node in the graph.
    InferenceContext context;
    PostorderDFSNodeType type;
    VisitState state;
    Visit visit;  // The handler to call once the dependencies are resolved
                  // into literal form.
    int64_t id;  // Unique id in the work queue, starting from 0.
    std::vector<CacheKey> dependencies;

    CacheKey GetCacheKey() { return CacheKey(handle, context, type); }
  };

  std::vector<WorkItem> stack;
  WorkItem root(handle, InferenceContext({}, {}), type, kUnvisited,
                unique_id++);
  stack.push_back(root);
  while (!stack.empty()) {
    WorkItem& item = stack.back();
    VLOG(1) << "stack top shape index: " << item.context.shape_index.ToString();
    if (VLOG_IS_ON(1)) {
      TF_RETURN_IF_ERROR(handle_to_instruction(item.handle).status());
      VLOG(1) << "stack top "
              << handle_to_instruction(item.handle).ValueOrDie()->DebugString();
    }
    if (item.state == kVisiting) {
      VLOG(1) << "visiting";
      // The dependencies are ready, visit the node itself.

      // Gather dependencies and transform them into literals.
      std::vector<Literal> literals;
      literals.reserve(item.dependencies.size());
      for (CacheKey& dep_key : item.dependencies) {
        TF_RET_CHECK(evaluated.contains(dep_key));
        literals.emplace_back(evaluated.at(dep_key).Clone());
      }
      VLOG(1) << "Start visiting with dependency type: "
              << PostorderDFSNodeTypeToString(item.type);
      TF_ASSIGN_OR_RETURN(auto literal, item.visit(absl::MakeSpan(literals)));
      VLOG(1) << "End visiting: " << literal.ToString();
      evaluated[item.GetCacheKey()] = std::move(literal);
      stack.pop_back();
      continue;
    }
    // This is the first time we see this node; gather its dependencies.
    VLOG(1) << "unvisited";
    if (evaluated.contains(item.GetCacheKey())) {
      stack.pop_back();
      continue;
    }
    item.state = kVisiting;
    PostorderDFSNode node;
    switch (item.type) {
      case PostorderDFSNodeType::kConstantValue: {
        VLOG(1) << "constant value";
        TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle, item.context));
        break;
      }
      case PostorderDFSNodeType::kConstantLowerBound: {
        VLOG(1) << "constant lower bound";
        TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle, item.context));
        break;
      }
      case PostorderDFSNodeType::kConstantUpperBound: {
        VLOG(1) << "constant upper bound";
        TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle, item.context));
        break;
      }
      case PostorderDFSNodeType::kBoundIsDynamic:
      case PostorderDFSNodeType::kValueIsDynamic: {
        VLOG(1) << "value is dynamic";
        TF_ASSIGN_OR_RETURN(
            node, AnalyzeIsDynamic(item.handle, item.type, item.context));
        break;
      }
    }
    // Store the visit function which is needed when its dependencies are
    // resolved.
    item.visit = node.visit;

    const int64_t current_item_id = stack.size() - 1;
    // Enqueue dependencies into the stack. `item` shouldn't be accessed after
    // this point.
    for (const PostorderDFSDep& dep : node.dependencies) {
      TF_ASSIGN_OR_RETURN(auto dependency_inst,
                          handle_to_instruction(dep.handle));
      VLOG(1) << "dependency " << dep.annotation
              << "::" << dependency_inst->DebugString() << " index "
              << dep.context.shape_index << " stack size: " << stack.size();
      stack.emplace_back(dep.handle, dep.context, dep.type, kUnvisited,
                         unique_id++);
      stack[current_item_id].dependencies.push_back(stack.back().GetCacheKey());
    }
  }
  VLOG(1) << "done " << evaluated[root.GetCacheKey()].ToString();
  return evaluated[root.GetCacheKey()].Clone();
}

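// Returns a PRED literal with the same shape as `op`, where `true` marks
// elements whose value cannot be proven static at compile time.
//
// Minimal usage sketch (hedged; assumes a ValueInference instance constructed
// over the XlaBuilder that produced `op`, per value_inference.h):
//
//   ValueInference value_inference(&builder);
//   TF_ASSIGN_OR_RETURN(Literal dynamism,
//                       value_inference.AnalyzeIsDynamic(op));
//   // dynamism.Get<bool>(index) == true  =>  element is dynamic.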
StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
  PostorderDFSVisitor visitor(
      evaluator_,
      [&](int64_t handle) {
        return builder_->LookUpInstructionByHandle(handle);
      },
      [&](int64_t handle) { return &(builder_->embedded_[handle]); });

  auto result = visitor.PostOrderDFSVisit(
      op.handle(), PostorderDFSNodeType::kValueIsDynamic);
  return result;
}

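// Looks up a previously seen instruction that is equivalent to `handle` and,
// if found, returns that instruction's handle so the caller can reuse its
// analysis; returns std::nullopt when there is nothing to deduplicate. Only
// kGetDimensionSize is considered, keyed in cse_map_ by a hash of its
// (operand, dimension) pair.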
StatusOr<std::optional<int64_t>> ValueInference::CseOpHandle(int64_t handle) {
  TF_ASSIGN_OR_RETURN(auto inst, builder_->LookUpInstructionByHandle(handle));
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode()));
  // For now, only handle kGetDimensionSize as that's the most duplicated one.
  if (opcode != HloOpcode::kGetDimensionSize) {
    return {std::nullopt};
  }
  int64_t hash = absl::HashOf(inst->operand_ids(0), inst->dimensions(0));
  auto lookup = cse_map_.find(hash);
  if (lookup == cse_map_.end()) {
    cse_map_[hash] = handle;
    return {std::nullopt};
  }
  TF_ASSIGN_OR_RETURN(auto equivalent_op,
                      builder_->LookUpInstructionByHandle(lookup->second));
  // Check that the op is indeed equivalent to guard against hash collisions,
  // which are relatively easy to hit with a 64-bit hash.
  if (equivalent_op->opcode() != inst->opcode() ||
      equivalent_op->operand_ids(0) != inst->operand_ids(0) ||
      equivalent_op->dimensions(0) != inst->dimensions(0)) {
    // Hash collision, don't CSE.
    return {std::nullopt};
  }
  int64_t cse = lookup->second;
  if (handle != cse) {
    // Successfully found a handle that's not the same as input but equivalent.
    return {cse};
  }
  return {std::nullopt};
}

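// Symbolically simplifies the expression rooted at `handle`. The result is an
// S64 literal that stores, per element, the handle of an op computing that
// element, or -1 when no simplification is known. Slice, concatenate, reshape,
// and broadcast are evaluated directly over these handle tensors; identity
// converts are forwarded; scalar adds are rewritten using patterns such as
// a + (b - a) => b.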
StatusOr<Literal> ValueInference::SimplifyOp(int64_t handle) {
  TF_ASSIGN_OR_RETURN(auto cse_handle, CseOpHandle(handle));
  if (cse_handle) {
    // Use the CSE'd handle instead.
    return SimplifyOp(*cse_handle);
  }
  TF_ASSIGN_OR_RETURN(auto* inst, builder_->LookUpInstructionByHandle(handle));
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode()));
  std::vector<Literal> operands;
  auto output_shape = Shape(inst->shape());
  switch (opcode) {
    case HloOpcode::kSlice:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kBroadcast: {
      for (auto operand_id : inst->operand_ids()) {
        TF_ASSIGN_OR_RETURN(auto literal, SimplifyOp(operand_id));
        operands.emplace_back(std::move(literal));
      }
      // We put handles into the tensor and evaluate the result into a literal.
      // The resulting literal contains a handle for each element position.
      return HloProtoEvaluator(evaluator_, *inst)
          .WithOperands(absl::MakeSpan(operands))
          .WithPrimitiveType(S64)
          .Evaluate();
    }
    case HloOpcode::kConvert: {
      // Only identity kConvert can be optimized away.
      auto operand = builder_->LookUpInstructionByHandle(inst->operand_ids(0))
                         .ValueOrDie();
      if (Shape::Equal()(output_shape, Shape(operand->shape()))) {
        // Forward operand handle as result.
        return SimplifyOp(inst->operand_ids(0));
      } else {
        return CreateS64Literal(-1, output_shape);
      }
    }
    case HloOpcode::kAdd: {
      // a + (b - a) => b
      // a + b + (c - a) => b + c
      if (output_shape.rank() == 0) {
        TF_ASSIGN_OR_RETURN(auto lhs, SimplifyOp(inst->operand_ids(0)));
        TF_ASSIGN_OR_RETURN(auto rhs, SimplifyOp(inst->operand_ids(1)));
        int64_t lhs_handle = lhs.Get<int64_t>({});
        int64_t rhs_handle = rhs.Get<int64_t>({});
        if (lhs_handle == -1 || rhs_handle == -1) {
          return CreateS64Literal(-1, output_shape);
        }
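        // can_be_optimized(l, r) returns the handle that the scalar sum l + r
        // simplifies to: either r has the form (x - l), giving x, or l is
        // itself an add whose halves allow the a + b + (c - a) => b + c
        // rewrite. It returns std::nullopt when no rewrite applies.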
        // Recursive lambda needs explicit signature.
        std::function<std::optional<int64_t>(int64_t, int64_t)>
            can_be_optimized;
        can_be_optimized = [this, &can_be_optimized](
                               int64_t lhs,
                               int64_t rhs) -> std::optional<int64_t> {
          auto rhs_inst = builder_->LookUpInstructionByHandle(rhs).ValueOrDie();
          HloOpcode rhs_opcode =
              StringToHloOpcode(rhs_inst->opcode()).ValueOrDie();
          if (rhs_opcode == HloOpcode::kSubtract) {
            auto sub_lhs_handle = SimplifyOp(rhs_inst->operand_ids(0))
                                      .ValueOrDie()
                                      .Get<int64_t>({});
            auto sub_rhs_handle = SimplifyOp(rhs_inst->operand_ids(1))
                                      .ValueOrDie()
                                      .Get<int64_t>({});
            if (sub_rhs_handle == lhs) {
              // lhs + (sub_lhs - sub_rhs) = sub_lhs if lhs == sub_rhs
              return sub_lhs_handle;
            }
          }

          // Check the case for a + b + (c - a) => b + c
          auto lhs_inst = builder_->LookUpInstructionByHandle(lhs).ValueOrDie();
          HloOpcode lhs_opcode =
              StringToHloOpcode(lhs_inst->opcode()).ValueOrDie();
          if (lhs_opcode == HloOpcode::kAdd) {
            auto add_lhs_handle = SimplifyOp(lhs_inst->operand_ids(0))
                                      .ValueOrDie()
                                      .Get<int64_t>({});
            auto add_rhs_handle = SimplifyOp(lhs_inst->operand_ids(1))
                                      .ValueOrDie()
                                      .Get<int64_t>({});
            if (auto optimized = can_be_optimized(add_lhs_handle, rhs)) {
              return Add(XlaOp(add_rhs_handle, builder_),
                         XlaOp(optimized.value(), builder_))
                  .handle();
            }
            if (auto optimized = can_be_optimized(add_rhs_handle, rhs)) {
              return Add(XlaOp(add_lhs_handle, builder_),
                         XlaOp(optimized.value(), builder_))
                  .handle();
            }
          }
          return std::nullopt;
        };
        if (auto optimized = can_be_optimized(lhs_handle, rhs_handle)) {
          return LiteralUtil::CreateR0<int64_t>(optimized.value());
        }
        // Swap lhs and rhs.
        if (auto optimized = can_be_optimized(rhs_handle, lhs_handle)) {
          return LiteralUtil::CreateR0<int64_t>(optimized.value());
        }
        // This sum can't be optimized, return sum of lhs and rhs. Note that we
        // can't just return the original sum as its lhs and rhs could be
        // optimized and different.
        XlaOp new_sum =
            Add(XlaOp(lhs_handle, builder_), XlaOp(rhs_handle, builder_));

        return LiteralUtil::CreateR0<int64_t>(new_sum.handle());
      } else {
        return CreateS64Literal(-1, output_shape);
      }
    }
    default: {
      if (ShapeUtil::IsScalar(output_shape)) {
        return LiteralUtil::CreateR0<int64_t>(handle);
      } else {
        return CreateS64Literal(-1, output_shape);
      }
    }
  }
}

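// Runs constant inference on `op` in the requested `mode` (lower bound, upper
// bound, or exact value). A PRED dynamism mask is computed first; the returned
// OptionalLiteral pairs the inferred values with that mask, and masked (i.e.
// dynamic) positions hold garbage values that must not be read. Scalars are
// first routed through SimplifyOp so an algebraically simplified handle is
// analyzed instead.
//
// Usage sketch (hedged; assumes a ValueInference built over the same
// XlaBuilder as `op`):
//
//   TF_ASSIGN_OR_RETURN(OptionalLiteral upper,
//                       value_inference.AnalyzeConstant(
//                           op, ValueInferenceMode::kUpperBound));
//   std::optional<int64_t> bound = upper.Get<int64_t>({});  // nullopt if
//                                                           // dynamic.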
StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
    XlaOp op, ValueInferenceMode mode) {
  TF_RETURN_IF_ERROR(builder_->LookUpInstructionByHandle(op.handle()).status());
  PostorderDFSVisitor visitor(
      evaluator_,
      [&](int64_t handle) {
        return builder_->LookUpInstructionByHandle(handle);
      },
      [&](int64_t handle) { return &(builder_->embedded_[handle]); });
  TF_ASSIGN_OR_RETURN(Shape op_shape, builder_->GetShape(op));
  int64_t handle = op.handle();
  if (ShapeUtil::IsScalar(builder_->GetShape(op).ValueOrDie())) {
    TF_ASSIGN_OR_RETURN(auto result, SimplifyOp(handle));
    auto optimized_handle = result.Get<int64_t>({});
    if (optimized_handle != -1) {
      handle = optimized_handle;
    }
  }
  switch (mode) {
    case ValueInferenceMode::kLowerBound: {
      TF_ASSIGN_OR_RETURN(Literal mask,
                          visitor.PostOrderDFSVisit(
                              handle, PostorderDFSNodeType::kBoundIsDynamic));
      if (mask.IsAll(1)) {
        // Everything is dynamic, no need to do constant inference.
        return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
      }
      TF_ASSIGN_OR_RETURN(
          Literal value,
          visitor.PostOrderDFSVisit(handle,
                                    PostorderDFSNodeType::kConstantLowerBound));

      return OptionalLiteral(std::move(value), std::move(mask));
    }
    case ValueInferenceMode::kUpperBound: {
      TF_ASSIGN_OR_RETURN(Literal mask,
                          visitor.PostOrderDFSVisit(
                              handle, PostorderDFSNodeType::kBoundIsDynamic));
      if (mask.IsAll(1)) {
        // Everything is dynamic, no need to do constant inference.
        return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
      }
      TF_ASSIGN_OR_RETURN(
          Literal value,
          visitor.PostOrderDFSVisit(handle,
                                    PostorderDFSNodeType::kConstantUpperBound));

      return OptionalLiteral(std::move(value), std::move(mask));
    }
    case ValueInferenceMode::kValue: {
      TF_ASSIGN_OR_RETURN(Literal mask,
                          visitor.PostOrderDFSVisit(
                              handle, PostorderDFSNodeType::kValueIsDynamic));
      if (mask.IsAll(1)) {
        // Everything is dynamic, no need to do constant inference.
        return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
      }
      TF_ASSIGN_OR_RETURN(Literal value,
                          visitor.PostOrderDFSVisit(
                              handle, PostorderDFSNodeType::kConstantValue));

      return OptionalLiteral(std::move(value), std::move(mask));
    }
  }
}

}  // namespace xla