1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/client/value_inference.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/comparison_util.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
31 #include "tensorflow/compiler/xla/service/hlo.pb.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/stream_executor/lib/statusor.h"
41 
42 namespace xla {
43 namespace {
44 Literal CreatePredLiteral(bool pred, const Shape& reference_shape) {
45   if (reference_shape.IsTuple()) {
46     std::vector<Literal> sub_literals;
47     const auto& reference_shape_tuple_shapes = reference_shape.tuple_shapes();
48     sub_literals.reserve(reference_shape_tuple_shapes.size());
49     for (const Shape& shape : reference_shape_tuple_shapes) {
50       sub_literals.emplace_back(CreatePredLiteral(pred, shape));
51     }
52     return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
53   }
54   PrimitiveType element_type = reference_shape.element_type();
55   if (element_type == TOKEN) {
56     return LiteralUtil::CreateR0(pred);
57   }
58   Literal literal = LiteralUtil::CreateR0(pred);
59   Literal literal_broadcast =
60       literal.Broadcast(ShapeUtil::ChangeElementType(reference_shape, PRED), {})
61           .ValueOrDie();
62   return literal_broadcast;
63 }
64 
65 Literal CreateS64Literal(int64_t value, const Shape& reference_shape) {
66   if (reference_shape.IsTuple()) {
67     std::vector<Literal> sub_literals;
68     const auto& reference_shape_tuple_shapes = reference_shape.tuple_shapes();
69     sub_literals.reserve(reference_shape_tuple_shapes.size());
70     for (const Shape& shape : reference_shape_tuple_shapes) {
71       sub_literals.emplace_back(CreateS64Literal(value, shape));
72     }
73     return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
74   }
75   PrimitiveType element_type = reference_shape.element_type();
76   if (element_type == TOKEN) {
77     return LiteralUtil::CreateToken();
78   }
79   Literal literal = LiteralUtil::CreateR0<int64_t>(value);
80   return literal
81       .Broadcast(ShapeUtil::ChangeElementType(reference_shape, S64), {})
82       .ValueOrDie();
83 }
84 
85 // Create a literal with garbage data. The data inside is undefined and
86 // shouldn't be used in any meaningful computation.
87 Literal CreateGarbageLiteral(const Shape& reference_shape) {
88   if (reference_shape.IsTuple()) {
89     std::vector<Literal> sub_literals;
90     for (const Shape& shape : reference_shape.tuple_shapes()) {
91       sub_literals.emplace_back(CreateGarbageLiteral(shape));
92     }
93     return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
94   }
95   PrimitiveType element_type = reference_shape.element_type();
96   if (element_type == TOKEN) {
97     return LiteralUtil::CreateToken();
98   }
99   Literal literal = LiteralUtil::One(element_type);
100   return literal.Broadcast(reference_shape, {}).ValueOrDie();
101 }
102 
103 // HloProtoEvaluator evaluates an HLO proto and returns a literal. The user has
104 // to provide the operands as literals through WithOperands().
105 struct HloProtoEvaluator {
106   explicit HloProtoEvaluator(HloEvaluator& evaluator, HloInstructionProto inst)
107       : evaluator(evaluator),
108         inst(std::move(inst)),
109         module("EmptyModuleForEvaluation", HloModuleConfig()) {}
110 
111   // WithComputation changes the called computation of the instruction being
112   // evaluated.
113   HloProtoEvaluator& WithComputation(
114       std::unique_ptr<HloComputation> new_computation) {
115     computation = new_computation.get();
116     computation->ClearUniqueIdInternal();
117     for (HloInstruction* inst : computation->instructions()) {
118       inst->ClearUniqueIdInternal();
119     }
120     module.AddEmbeddedComputation(std::move(new_computation));
121     return *this;
122   }
123 
124   // WithPrimitiveType changes the primitive type of the instruction being
125   // evaluated.
126   HloProtoEvaluator& WithPrimitiveType(PrimitiveType new_primitive_type) {
127     primitive_type = new_primitive_type;
128     return *this;
129   }
130 
131   // WithOpCode changes the opcode of the instruction being evaluated.
132   HloProtoEvaluator& WithOpCode(HloOpcode new_opcode) {
133     opcode = new_opcode;
134     return *this;
135   }
136 
137   // WithOperands changes the operands of the instruction being evaluated.
138   HloProtoEvaluator& WithOperands(absl::Span<Literal> operands) {
139     this->operands = operands;
140     return *this;
141   }
142 
143   // When WithSubshape is set, the result tuple will be decomposed and only the
144   // literal at the given shape index will be returned.
145   HloProtoEvaluator& WithSubshape(ShapeIndex shape_index) {
146     this->shape_index = std::move(shape_index);
147     return *this;
148   }
149 
150   StatusOr<Literal> Evaluate() {
151     // Evaluate the instruction by swapping its operands with constant
152     // instructions holding the given literals.
153     HloComputation::Builder builder("EmptyComputation");
154     absl::flat_hash_map<int64_t, HloInstruction*> operand_map;
155     for (int64_t i = 0; i < inst.operand_ids_size(); ++i) {
156       int64_t operand_handle = inst.operand_ids(i);
157       std::unique_ptr<HloInstruction> operand =
158           HloInstruction::CreateConstant(operands[i].Clone());
159       operand_map[operand_handle] = operand.get();
160       builder.AddInstruction(std::move(operand));
161     }
162 
163     if (primitive_type.has_value()) {
164       *inst.mutable_shape() = ShapeUtil::ChangeElementType(
165                                   Shape(inst.shape()), primitive_type.value())
166                                   .ToProto();
167     }
168     if (opcode.has_value()) {
169       *inst.mutable_opcode() = HloOpcodeString(opcode.value());
170     }
171     absl::flat_hash_map<int64_t, HloComputation*> computation_map;
172     if (inst.called_computation_ids_size() != 0) {
173       TF_RET_CHECK(inst.called_computation_ids_size() == 1 &&
174                    computation != nullptr)
175           << inst.DebugString();
176       computation_map[inst.called_computation_ids(0)] = computation;
177     }
178     TF_ASSIGN_OR_RETURN(
179         auto new_instruction,
180         HloInstruction::CreateFromProto(inst, operand_map, computation_map));
181     new_instruction->ClearUniqueIdInternal();
182     builder.AddInstruction(std::move(new_instruction));
183     auto computation = builder.Build();
184     module.AddEntryComputation(std::move(computation));
185     if (shape_index.empty()) {
186       return evaluator.Evaluate(module.entry_computation()->root_instruction());
187     } else {
188       TF_ASSIGN_OR_RETURN(
189           auto result,
190           evaluator.Evaluate(module.entry_computation()->root_instruction()));
191       return result.SubLiteral(this->shape_index);
192     }
193   }
194 
195   HloEvaluator& evaluator;
196   HloInstructionProto inst;
197 
198   HloModule module;
199   absl::Span<Literal> operands;
200   ShapeIndex shape_index = {};
201   HloComputation* computation = nullptr;
202   std::optional<PrimitiveType> primitive_type = std::nullopt;
203   std::optional<HloOpcode> opcode = std::nullopt;
204 };
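// A minimal usage sketch of the builder-style API above, assuming an
// HloInstructionProto for a binary add whose operands have already been
// resolved to literals. `add_proto` and `operand_literals` are hypothetical
// names used only for illustration:
//
//   HloEvaluator evaluator;
//   std::vector<Literal> operand_literals;  // one literal per operand id
//   StatusOr<Literal> sum =
//       HloProtoEvaluator(evaluator, add_proto)
//           .WithOperands(absl::MakeSpan(operand_literals))
//           .Evaluate();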
205 
206 enum PostorderDFSNodeType {
207   // This node is about figuring out the constant value.
208   kConstantValue = 0,
209   // This node is about figuring out the constant bound.
210   kConstantUpperBound,
211   kConstantLowerBound,
212   // This node is about figuring out whether a value is dynamic.
213   kValueIsDynamic,
214   // This node is about figuring out whether a bound value is dynamic. It's
215   // similar to kValueIsDynamic, but treats shape bounds as static values.
216   kBoundIsDynamic,
217 };
218 
219 std::string PostorderDFSNodeTypeToString(PostorderDFSNodeType type) {
220   switch (type) {
221     case kConstantValue:
222       return "kConstantValue";
223     case kConstantUpperBound:
224       return "kConstantUpperBound";
225     case kConstantLowerBound:
226       return "kConstantLowerBound";
227     case kValueIsDynamic:
228       return "kValueIsDynamic";
229     case kBoundIsDynamic:
230       return "kBoundIsDynamic";
231   }
232 }
233 
234 struct InferenceContext {
235   explicit InferenceContext(ShapeIndex shape_index,
236                             std::vector<int64_t> caller_operand_handles)
237       : shape_index(std::move(shape_index)),
238         caller_operand_handles(std::move(caller_operand_handles)) {}
239   // `shape_index` represents the subshape that we care about in the inference.
240   // It is used to avoid materializing the whole tuple when we only care about a
241   // subtensor of it.
242   ShapeIndex shape_index;
243 
244   // caller_operand_handles is a stack that helps argument forwarding. The top
245   // of the stack represents the tensor to be forwarded to the
246   // parameter of the innermost computation. E.g.,:
247   // inner_true_computation {
248   //   inner_p0 = param()
249   //   ...
250   // }
251   //
252   // true_computation {
253   //   p0 = param()
254   //   conditional(pred, p0, inner_true_computation,
255   //                     ...)
256   // }
257   //
258   // main {
259   //   op = ..
260   //   conditional(pred, op, true_computation, ...)
261   // }
262   //
263   // In this case, when we analyze inner_true_computation, the
264   // `caller_operand_handles` will be [op, p0] -- p0 is what should be
265   // forwarded to inner_p0 and op is what should be forwarded to p0. Similarly,
266   // when we analyze true_computation, the `caller_operand_handles` will be
267   // [op].
268   std::vector<int64_t> caller_operand_handles;
269 };
270 
271 // Each node in the postorder traversal tree may depend on traversing the
272 // values of the node's children.
273 struct PostorderDFSDep {
274   explicit PostorderDFSDep(int64_t handle, PostorderDFSNodeType type,
275                            InferenceContext context, std::string annotation)
276       : handle(handle),
277         type(type),
278         context(std::move(context)),
279         annotation(std::move(annotation)) {}
280   int64_t handle;
281   PostorderDFSNodeType type;
282   InferenceContext context;
283   std::string annotation;
284 };
285 
286 // This function represents the logic to visit a node once its dependencies
287 // (operands) are all resolved.
288 using Visit = std::function<StatusOr<Literal>(absl::Span<Literal>)>;
290 // Convenient specializations of the Visit function for different numbers of operands.
290 using Visit0D = std::function<StatusOr<Literal>()>;
291 using Visit1D = std::function<StatusOr<Literal>(Literal)>;
292 using Visit2D = std::function<StatusOr<Literal>(Literal, Literal)>;
293 
294 // A postorder dfs node can be visited once its dependency requests are all
295 // fulfilled.
296 struct [[nodiscard]] PostorderDFSNode {
297   PostorderDFSNode& AddDependency(int64_t handle, PostorderDFSNodeType type,
298                                   InferenceContext context,
299                                   std::string annotation = "") {
300     dependencies.emplace_back(handle, type, std::move(context),
301                               std::move(annotation));
302     return *this;
303   }
304 
305   PostorderDFSNode& AddVisit(const Visit& visit) {
306     this->visit = visit;
307     return *this;
308   }
309 
310   PostorderDFSNode& AddVisit(const Visit0D& visit) {
311     this->visit = [visit](absl::Span<Literal> literals) { return visit(); };
312     return *this;
313   }
314 
315   PostorderDFSNode& AddVisit(const Visit1D& visit) {
316     this->visit = [visit](absl::Span<Literal> literals) {
317       return visit(std::move(literals[0]));
318     };
319     return *this;
320   }
321 
322   PostorderDFSNode& AddVisit(const Visit2D& visit) {
323     this->visit = [visit](absl::Span<Literal> literals) {
324       return visit(std::move(literals[0]), std::move(literals[1]));
325     };
326     return *this;
327   }
328   std::vector<PostorderDFSDep> dependencies;
329   Visit visit;
330 };
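// A minimal sketch of how the analysis methods below assemble these nodes: a
// node that depends on one operand's upper bound and, once that dependency is
// resolved, negates it. `operand_handle`, `ctx`, and `eval` are hypothetical
// names used only for illustration.
//
//   PostorderDFSNode node =
//       PostorderDFSNode()
//           .AddDependency(operand_handle,
//                          PostorderDFSNodeType::kConstantUpperBound, ctx)
//           .AddVisit([&eval](Literal upper_bound) -> StatusOr<Literal> {
//             return eval.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
//                                                    upper_bound);
//           });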
331 
332 // Converts an integer handle to an HloInstructionProto.
333 using HandleToInstruction =
334     std::function<StatusOr<const HloInstructionProto*>(int64_t)>;
335 using HandleToComputation = std::function<const HloComputationProto*(int64_t)>;
336 
337 struct PostorderDFSVisitor {
338   PostorderDFSVisitor(HloEvaluator& evaluator,
339                       HandleToInstruction handle_to_instruction,
340                       HandleToComputation handle_to_computation)
341       : evaluator(evaluator),
342         handle_to_instruction(handle_to_instruction),
343         handle_to_computation(handle_to_computation) {}
344 
345   StatusOr<PostorderDFSNode> AnalyzeUpperBound(int64_t handle,
346                                                InferenceContext context);
347   StatusOr<PostorderDFSNode> AnalyzeLowerBound(int64_t handle,
348                                                InferenceContext context);
349   StatusOr<PostorderDFSNode> AnalyzeIsDynamic(int64_t handle,
350                                               PostorderDFSNodeType type,
351                                               InferenceContext context);
352   StatusOr<PostorderDFSNode> AnalyzeConstant(int64_t handle,
353                                              InferenceContext context);
354   StatusOr<PostorderDFSNode> AnalyzeConstantValueFallback(
355       int64_t handle, PostorderDFSNodeType type, InferenceContext context);
356 
357   StatusOr<Literal> PostOrderDFSVisit(int64_t handle,
358                                       PostorderDFSNodeType type);
359 
360   // Returns true if a value represented by `handle` is an integral type or
361   // a floating-point type that was just converted from an integral type.
362   // E.g.,:
363   // int(a) -> true
364   // float(int(a)) -> true
365   // float(a) -> false -- We don't know the concrete value of `a` at
366   // compile time, except for its type.
367   bool IsValueEffectiveInteger(int64_t handle) {
368     // handle_to_instruction's failure status should be checked by parent.
369     const HloInstructionProto* instr =
370         handle_to_instruction(handle).ValueOrDie();
371     if (primitive_util::IsIntegralType(instr->shape().element_type())) {
372       return true;
373     }
374     // Also returns true if this is a convert that converts an integer to float.
375     HloOpcode opcode = StringToHloOpcode(instr->opcode()).ValueOrDie();
376     if (opcode != HloOpcode::kConvert) {
377       return false;
378     }
379     const HloInstructionProto* parent =
380         handle_to_instruction(instr->operand_ids(0)).ValueOrDie();
381     if (primitive_util::IsIntegralType(parent->shape().element_type())) {
382       return true;
383     }
384     return false;
385   }
386 
387   // Checks the sizes of the outputs and inputs. Returns true if any of them
388   // exceeds kLargeShapeElementLimit and the instruction needs evaluation (e.g.,
389   // kGetDimensionSize and kSetDimensionSize don't need evaluation).
390   bool IsInstructionOverLimit(const HloInstructionProto* proto,
391                               InferenceContext context) {
392     Shape subshape =
393         ShapeUtil::GetSubshape(Shape(proto->shape()), context.shape_index);
394 
395     if (subshape.IsArray() &&
396         ShapeUtil::ElementsIn(subshape) > kLargeShapeElementLimit) {
397       return true;
398     }
399     HloOpcode opcode = StringToHloOpcode(proto->opcode()).ValueOrDie();
400     for (int64_t operand_id : proto->operand_ids()) {
401       const HloInstructionProto* operand =
402           handle_to_instruction(operand_id).ValueOrDie();
403       Shape operand_shape = Shape(operand->shape());
404 
405       if (operand_shape.IsArray() &&
406           ShapeUtil::ElementsIn(operand_shape) > kLargeShapeElementLimit &&
407           opcode != HloOpcode::kGetDimensionSize &&
408           opcode != HloOpcode::kSetDimensionSize) {
409         return true;
410       }
411     }
412     return false;
413   }
414 
415   struct CacheKey {
416     CacheKey(int64_t handle, InferenceContext context,
417              PostorderDFSNodeType type)
418         : handle(handle), context(context), type(type) {}
419     int64_t handle;
420     InferenceContext context;
421     PostorderDFSNodeType type;
422 
423     template <typename H>
424     friend H AbslHashValue(H h, const CacheKey& key) {
425       h = H::combine(std::move(h), key.handle);
426       h = H::combine(std::move(h), key.context.shape_index.ToString());
427       h = H::combine(std::move(h),
428                      VectorString(key.context.caller_operand_handles));
429       h = H::combine(std::move(h), key.type);
430       return h;
431     }
432 
433     friend bool operator==(const CacheKey& lhs, const CacheKey& rhs) {
434       return lhs.handle == rhs.handle &&
435              lhs.context.shape_index == rhs.context.shape_index &&
436              lhs.context.caller_operand_handles ==
437                  rhs.context.caller_operand_handles &&
438              lhs.type == rhs.type;
439     }
440   };
441 
442   HloEvaluator& evaluator;
443   absl::flat_hash_map<CacheKey, Literal> evaluated;
444   HandleToInstruction handle_to_instruction;
445   HandleToComputation handle_to_computation;
446   // Give up when dealing with more than 1M elements.
447   static constexpr int64_t kLargeShapeElementLimit = 1000 * 1000;
448 };
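// A minimal sketch of driving the visitor, assuming `lookup_instruction`,
// `lookup_computation`, and `handle` are hypothetical callbacks and a handle
// that resolve against the enclosing builder's protos:
//
//   HloEvaluator evaluator;
//   PostorderDFSVisitor visitor(evaluator,
//                               /*handle_to_instruction=*/lookup_instruction,
//                               /*handle_to_computation=*/lookup_computation);
//   StatusOr<Literal> upper_bound = visitor.PostOrderDFSVisit(
//       handle, PostorderDFSNodeType::kConstantUpperBound);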
449 
450 // Returns a result representing that the value is fully dynamic and can't be
451 // inferred. In other words, "give up" and return the most conservative value.
452 PostorderDFSNode CreateAllDynamicResult(Shape shape,
453                                         PostorderDFSNodeType type) {
454   return PostorderDFSNode().AddVisit(
455       [shape, type](absl::Span<Literal>) -> Literal {
456         if (type == PostorderDFSNodeType::kConstantValue ||
457             type == PostorderDFSNodeType::kConstantUpperBound ||
458             type == PostorderDFSNodeType::kConstantLowerBound) {
459           // When inferring constant values, create garbage data, which will
460           // be masked out by the dynamism counterpart.
461           return CreateGarbageLiteral(shape);
462         } else {
463           // When inferring dynamism, return true: all values are dynamic.
464           return CreatePredLiteral(true, shape);
465         }
466       });
467 }
468 
469 }  // namespace
470 
471 // Analyze a tensor's constant value, upper-bound value or lower-bound value.
472 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstantValueFallback(
473     int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
474   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
475                       handle_to_instruction(handle));
476   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
477   Shape subshape =
478       ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
479   PostorderDFSNode result;
480   // By default, the dependencies of the current node are its operands.
481   for (auto operand_id : root->operand_ids()) {
482     InferenceContext dep_context = context;
483     dep_context.shape_index = {};
484     result.AddDependency(operand_id, type, dep_context);
485   }
486   switch (opcode) {
487       // Non-functional ops.
488     case HloOpcode::kRng:
489     case HloOpcode::kAllReduce:
490     case HloOpcode::kReduceScatter:
491     case HloOpcode::kInfeed:
492     case HloOpcode::kOutfeed:
493     case HloOpcode::kRngBitGenerator:
494     case HloOpcode::kCustomCall:
495     case HloOpcode::kWhile:
496     case HloOpcode::kSend:
497     case HloOpcode::kRecv:
498     case HloOpcode::kSendDone:
499     case HloOpcode::kRecvDone:
500     case HloOpcode::kParameter: {
501       if (opcode == HloOpcode::kParameter &&
502           !context.caller_operand_handles.empty()) {
503         int64_t caller_operand = context.caller_operand_handles.back();
504         context.caller_operand_handles.pop_back();
505         return result.AddDependency(caller_operand, type, context)
506             .AddVisit([](Literal literal) { return literal; });
507       }
508       return CreateAllDynamicResult(subshape, type);
509     }
510     // Subtract and Divide use lower-bound as second operand.
511     case HloOpcode::kSubtract:
512     case HloOpcode::kCos:
513     case HloOpcode::kSin:
514     case HloOpcode::kNegate:
515     case HloOpcode::kAbs:
516     case HloOpcode::kDivide:
517     case HloOpcode::kGetDimensionSize: {
518       return InvalidArgument(
519           "AnalyzeConstantValueFallback can't handle opcode: %s",
520           root->opcode());
521     }
522     case HloOpcode::kCall: {
523       auto node = PostorderDFSNode();
524       auto* call_proto = root;
525       if (call_proto->operand_ids_size() != 1) {
526         // Only support single operand forwarding.
527         return CreateAllDynamicResult(subshape, type);
528       }
529       int64_t called_root =
530           handle_to_computation(call_proto->called_computation_ids(0))
531               ->root_id();
532       InferenceContext call_context = context;
533       call_context.caller_operand_handles.push_back(call_proto->operand_ids(0));
534       node.AddDependency(called_root, PostorderDFSNodeType::kConstantValue,
535                          call_context, "callee's root instruction");
536       return node.AddVisit([](Literal operand) -> StatusOr<Literal> {
537         // Forward result of callee's root to caller.
538         return std::move(operand);
539       });
540     }
541 
542     case HloOpcode::kConditional: {
543       auto node = PostorderDFSNode();
544       auto* conditional_proto = root;
545       InferenceContext predicate_context = context;
546       predicate_context.shape_index = {};
547       // Add dependencies to analyze the predicate of the conditional.
548       node.AddDependency(conditional_proto->operand_ids(0),
549                          PostorderDFSNodeType::kConstantValue,
550                          predicate_context)
551           .AddDependency(conditional_proto->operand_ids(0),
552                          PostorderDFSNodeType::kValueIsDynamic,
553                          predicate_context);
554       const int64_t branch_size =
555           conditional_proto->called_computation_ids_size();
556       for (int64_t i = 0; i < branch_size; ++i) {
557         int64_t branch_root =
558             handle_to_computation(conditional_proto->called_computation_ids(i))
559                 ->root_id();
560         InferenceContext branch_context = context;
561         branch_context.caller_operand_handles.push_back(
562             conditional_proto->operand_ids(i + 1));
563         node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
564                            branch_context);
565       }
566       return node.AddVisit(
567           [](absl::Span<Literal> operands) -> StatusOr<Literal> {
568             int64_t pred_is_dynamic = operands[1].Get<bool>({});
569             if (pred_is_dynamic) {
570               // If the predicate is dynamic, return the value of the first
571               // branch: if all branches return the same value, this is the
572               // value that we want; if not, the value will be masked anyway,
573               // so the value inside doesn't matter.
574               return std::move(operands[2]);
575             } else {
576               // If predicate is static, return the value of the given branch.
577               int64_t branch_index = 0;
578               if (operands[0].shape().element_type() == PRED) {
579                 if (operands[0].Get<bool>({})) {
580                   branch_index = 0;
581                 } else {
582                   branch_index = 1;
583                 }
584               } else {
585                 branch_index = operands[0].GetIntegralAsS64({}).value();
586               }
587               const int64_t branch_dynamism_index = 2 + branch_index;
588               return std::move(operands[branch_dynamism_index]);
589             }
590           });
591     }
592     case HloOpcode::kGetTupleElement: {
593       int64_t operand_handle = root->operand_ids(0);
594       PostorderDFSNode result;
595       context.shape_index.push_front(root->tuple_index());
596       return PostorderDFSNode()
597           .AddDependency(operand_handle, type, context)
598           .AddVisit([](Literal operand) { return operand; });
599     }
600     case HloOpcode::kReduce:
601     case HloOpcode::kSort:
602     case HloOpcode::kScatter:
603     case HloOpcode::kReduceWindow: {
604       const HloComputationProto* computation_proto =
605           handle_to_computation(root->called_computation_ids(0));
606       return result.AddVisit(
607           [root, computation_proto, context,
608            this](absl::Span<Literal> operands) -> StatusOr<Literal> {
609             TF_ASSIGN_OR_RETURN(
610                 auto computation,
611                 HloComputation::CreateFromProto(*computation_proto, {}));
612             return HloProtoEvaluator(evaluator, *root)
613                 .WithOperands(operands)
614                 .WithComputation(std::move(computation))
615                 .WithSubshape(context.shape_index)
616                 .Evaluate();
617           });
618     }
619     default: {
620       if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
621         // There could be many operands of a tuple, but only one that we are
622         // interested in, represented by `tuple_operand_index`.
623         int64_t tuple_operand_index = context.shape_index.front();
624         InferenceContext tuple_operand_context = context;
625         tuple_operand_context.shape_index.pop_front();
626         return PostorderDFSNode()
627             .AddDependency(root->operand_ids(tuple_operand_index), type,
628                            tuple_operand_context)
629             .AddVisit([](Literal operand) { return operand; });
630       }
631       return result.AddVisit([root, this](absl::Span<Literal> operands) {
632         return HloProtoEvaluator(evaluator, *root)
633             .WithOperands(operands)
634             .Evaluate();
635       });
636     }
637   }
638 }
639 
640 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeUpperBound(
641     int64_t handle, InferenceContext context) {
642   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
643                       handle_to_instruction(handle));
644   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
645   Shape subshape =
646       ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
647 
648   if (IsInstructionOverLimit(root, context)) {
649     return CreateAllDynamicResult(subshape,
650                                   PostorderDFSNodeType::kConstantUpperBound);
651   }
652   switch (opcode) {
653     case HloOpcode::kGetDimensionSize: {
654       int64_t dimension = root->dimensions(0);
655       int64_t operand_handle = root->operand_ids(0);
656       const HloInstructionProto* operand_proto =
657           handle_to_instruction(operand_handle).ValueOrDie();
658       return PostorderDFSNode().AddVisit(
659           [operand_proto, dimension]() -> StatusOr<Literal> {
660             return LiteralUtil::CreateR0<int32_t>(
661                 operand_proto->shape().dimensions(dimension));
662           });
663     }
664     case HloOpcode::kAbs: {
665       // upper-bound(abs(operand)) = max(abs(lower-bound(operand)),
666       //                                 abs(upper-bound(operand)))
667       return PostorderDFSNode()
668           .AddDependency(root->operand_ids(0),
669                          PostorderDFSNodeType::kConstantLowerBound, context)
670           .AddDependency(root->operand_ids(0),
671                          PostorderDFSNodeType::kConstantUpperBound, context)
672           .AddVisit([this](Literal lower_bound,
673                            Literal upper_bound) -> StatusOr<Literal> {
674             TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
675                                 evaluator.EvaluateElementwiseUnaryOp(
676                                     HloOpcode::kAbs, lower_bound));
677             TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
678                                 evaluator.EvaluateElementwiseUnaryOp(
679                                     HloOpcode::kAbs, upper_bound));
680             return evaluator.EvaluateElementwiseBinaryOp(
681                 HloOpcode::kMaximum, lower_bound_abs, upper_bound_abs);
682           });
683     }
684     case HloOpcode::kSort: {
685       auto dfs = PostorderDFSNode();
686       InferenceContext dep_context = context;
687       dep_context.shape_index = {};
688       if (!context.shape_index.empty()) {
689         // Lazy evaluation: Only need to evaluate a subelement in a
690         // variadic-sort tensor.
691         dfs.AddDependency(root->operand_ids(context.shape_index[0]),
692                           PostorderDFSNodeType::kConstantUpperBound,
693                           dep_context);
694       } else {
695         for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
696           dfs.AddDependency(root->operand_ids(i),
697                             PostorderDFSNodeType::kConstantUpperBound,
698                             dep_context);
699         }
700       }
701 
702       return dfs.AddVisit(
703           [root, context](absl::Span<Literal> operands) -> StatusOr<Literal> {
704             std::vector<Literal> results;
705             results.reserve(operands.size());
706             // Conservatively set each element of the tensor to the max value.
707             for (int64_t i = 0; i < operands.size(); ++i) {
708               auto max = LiteralUtil::MaxElement(operands[i]);
709               results.emplace_back(
710                   max.Broadcast(operands[i].shape(), {}).ValueOrDie());
711             }
712             if (ShapeUtil::GetSubshape(Shape(root->shape()),
713                                        context.shape_index)
714                     .IsTuple()) {
715               return LiteralUtil::MakeTupleOwned(std::move(results));
716             } else {
717               return std::move(results[0]);
718             }
719           });
720     }
721     case HloOpcode::kNegate: {
722       // upper-bound(negate(operand)) = negate(lower-bound(operand))
723       return PostorderDFSNode()
724           .AddDependency(root->operand_ids(0),
725                          PostorderDFSNodeType::kConstantLowerBound, context)
726           .AddVisit([this](Literal lower_bound) -> StatusOr<Literal> {
727             return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
728                                                         lower_bound);
729           });
730     }
731     case HloOpcode::kSubtract:
732     case HloOpcode::kDivide: {
733       // The lower bound is used for the second operand of subtract and divide.
734       return PostorderDFSNode()
735           .AddDependency(root->operand_ids(0),
736                          PostorderDFSNodeType::kConstantUpperBound, context)
737           .AddDependency(root->operand_ids(1),
738                          PostorderDFSNodeType::kConstantLowerBound, context)
739           .AddVisit([root, opcode, this](
740                         Literal upper_bound,
741                         Literal lower_bound) -> StatusOr<Literal> {
742             if (opcode == HloOpcode::kDivide &&
743                 this->IsValueEffectiveInteger(root->operand_ids(1))) {
744               // Because in many cases the lower bound of a value is
745               // integer 0, instead of throwing a divide-by-zero error
746               // at compile time we defer the check to runtime: where the
747               // lower bound of the divisor is 0 we replace it with 1 as a
748               // placeholder.
749               auto zero = LiteralUtil::Zero(lower_bound.shape().element_type());
750               zero = zero.Broadcast(lower_bound.shape(), {}).ValueOrDie();
751               TF_ASSIGN_OR_RETURN(
752                   auto lower_bound_is_zero,
753                   evaluator.EvaluateElementwiseCompareOp(
754                       ComparisonDirection::kEq, lower_bound, zero));
755 
756               auto one = LiteralUtil::One(lower_bound.shape().element_type());
757               one = one.Broadcast(lower_bound.shape(), {}).ValueOrDie();
758               TF_ASSIGN_OR_RETURN(
759                   lower_bound, evaluator.EvaluateElementwiseTernaryOp(
760                                    HloOpcode::kSelect, lower_bound_is_zero, one,
761                                    lower_bound));
762             }
763             std::vector<Literal> new_operands;
764             new_operands.emplace_back(std::move(upper_bound));
765             new_operands.emplace_back(std::move(lower_bound));
766             return HloProtoEvaluator(evaluator, *root)
767                 .WithOperands(absl::MakeSpan(new_operands))
768                 .Evaluate();
769           });
770     }
771     case HloOpcode::kCustomCall: {
772       if (root->custom_call_target() == "SetBound") {
773         return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
774           if (root->literal().shape().element_type() == TUPLE) {
775             // First literal of SetBound contains bounds, second literal
776             // contains dynamism indicators.
777             return Literal::CreateFromProto(root->literal().tuple_literals(0));
778           } else {
779             return Literal::CreateFromProto(root->literal());
780           }
781         });
782       } else if (root->custom_call_target() == "Sharding") {
783         return PostorderDFSNode()
784             .AddDependency(root->operand_ids(0),
785                            PostorderDFSNodeType::kConstantUpperBound, context)
786             .AddVisit([](Literal operand) { return operand; });
787       }
788       return InvalidArgument(
789           "Upper-bound inferencing on custom call %s is not supported",
790           root->DebugString());
791     }
792     case HloOpcode::kGather: {
793       return PostorderDFSNode()
794           .AddDependency(root->operand_ids(0),
795                          PostorderDFSNodeType::kConstantUpperBound, context)
796           .AddDependency(root->operand_ids(1),
797                          PostorderDFSNodeType::kConstantValue, context)
798           .AddVisit([root, this](absl::Span<Literal> operands) {
799             return HloProtoEvaluator(evaluator, *root)
800                 .WithOperands(operands)
801                 .Evaluate();
802           });
803     }
804     default:
805       return AnalyzeConstantValueFallback(
806           handle, PostorderDFSNodeType::kConstantUpperBound, context);
807   }
808 }
809 
810 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeLowerBound(
811     int64_t handle, InferenceContext context) {
812   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
813                       handle_to_instruction(handle));
814   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
815   Shape subshape =
816       ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
817   if (IsInstructionOverLimit(root, context)) {
818     return CreateAllDynamicResult(subshape,
819                                   PostorderDFSNodeType::kConstantLowerBound);
820   }
821   switch (opcode) {
822     case HloOpcode::kGetDimensionSize: {
823       int64_t dimension = root->dimensions(0);
824       int64_t operand_handle = root->operand_ids(0);
825       TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
826                           handle_to_instruction(operand_handle));
827       return PostorderDFSNode().AddVisit(
828           [dimension, operand_proto]() -> StatusOr<Literal> {
829             if (operand_proto->shape().is_dynamic_dimension(dimension)) {
830               return LiteralUtil::CreateR0<int32_t>(0);
831             } else {
832               return LiteralUtil::CreateR0<int32_t>(
833                   operand_proto->shape().dimensions(dimension));
834             }
835           });
836     }
837     case HloOpcode::kAbs: {
838       // lower-bound(abs(operand)) = min(abs(lower-bound(operand)),
839       // abs(upper-bound(operand)))
840       return PostorderDFSNode()
841           .AddDependency(root->operand_ids(0),
842                          PostorderDFSNodeType::kConstantLowerBound, context)
843           .AddDependency(root->operand_ids(0),
844                          PostorderDFSNodeType::kConstantUpperBound, context)
845           .AddVisit([this](Literal lower_bound,
846                            Literal upper_bound) -> StatusOr<Literal> {
847             TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
848                                 evaluator.EvaluateElementwiseUnaryOp(
849                                     HloOpcode::kAbs, lower_bound));
850             TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
851                                 evaluator.EvaluateElementwiseUnaryOp(
852                                     HloOpcode::kAbs, upper_bound));
853             return evaluator.EvaluateElementwiseBinaryOp(
854                 HloOpcode::kMinimum, lower_bound_abs, upper_bound_abs);
855           });
856     }
857     case HloOpcode::kNegate: {
858       // lower-bound(negate(operand)) = negate(upper-bound(operand))
859       return PostorderDFSNode()
860           .AddDependency(root->operand_ids(0),
861                          PostorderDFSNodeType::kConstantUpperBound, context)
862           .AddVisit([this](Literal upper_bound) -> StatusOr<Literal> {
863             return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
864                                                         upper_bound);
865           });
866     }
867     case HloOpcode::kSubtract:
868     case HloOpcode::kDivide: {
869       // The upper bound is used for the second operand of subtract and divide.
870       return PostorderDFSNode()
871           .AddDependency(root->operand_ids(0),
872                          PostorderDFSNodeType::kConstantLowerBound, context)
873           .AddDependency(root->operand_ids(1),
874                          PostorderDFSNodeType::kConstantUpperBound, context)
875           .AddVisit(
876               [root, this](absl::Span<Literal> operands) -> StatusOr<Literal> {
877                 return HloProtoEvaluator(evaluator, *root)
878                     .WithOperands(operands)
879                     .Evaluate();
880               });
881     }
882     case HloOpcode::kGather: {
883       return PostorderDFSNode()
884           .AddDependency(root->operand_ids(0),
885                          PostorderDFSNodeType::kConstantLowerBound, context)
886           .AddDependency(root->operand_ids(1),
887                          PostorderDFSNodeType::kConstantValue, context)
888           .AddVisit([root, this](absl::Span<Literal> operands) {
889             return HloProtoEvaluator(evaluator, *root)
890                 .WithOperands(operands)
891                 .Evaluate();
892           });
893     }
894     default:
895       return AnalyzeConstantValueFallback(
896           handle, PostorderDFSNodeType::kConstantLowerBound, context);
897   }
898 }
899 
900 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstant(
901     int64_t handle, InferenceContext context) {
902   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
903                       handle_to_instruction(handle));
904   HloOpcode opcode = StringToHloOpcode(root->opcode()).ValueOrDie();
905   Shape subshape =
906       ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
907   if (IsInstructionOverLimit(root, context)) {
908     return CreateAllDynamicResult(subshape,
909                                   PostorderDFSNodeType::kConstantValue);
910   }
911   switch (opcode) {
912     case HloOpcode::kGetDimensionSize: {
913       int64_t dimension = root->dimensions(0);
914       int64_t operand_handle = root->operand_ids(0);
915       TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
916                           handle_to_instruction(operand_handle));
917       return PostorderDFSNode().AddVisit(
918           [operand_proto, dimension, root]() -> StatusOr<Literal> {
919             if (operand_proto->shape().is_dynamic_dimension(dimension)) {
920             // The value is dynamic; we return garbage data here and mask it
921             // out later.
922               return CreateGarbageLiteral(Shape(root->shape()));
923             } else {
924               return LiteralUtil::CreateR0<int32_t>(
925                   operand_proto->shape().dimensions(dimension));
926             }
927           });
928     }
929     case HloOpcode::kSubtract:
930     case HloOpcode::kCos:
931     case HloOpcode::kSin:
932     case HloOpcode::kNegate:
933     case HloOpcode::kAbs:
934     case HloOpcode::kDivide: {
935       PostorderDFSNode result;
936       for (auto operand_id : root->operand_ids()) {
937         result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
938                              context);
939       }
940       return result.AddVisit(
941           [root, this](absl::Span<Literal> operands) -> StatusOr<Literal> {
942             return HloProtoEvaluator(evaluator, *root)
943                 .WithOperands(operands)
944                 .Evaluate();
945           });
946     }
947     case HloOpcode::kCustomCall: {
948       if (root->custom_call_target() == "SetBound") {
949         // `SetBound` doesn't change the static value of a tensor, so forward
950         // the operand when analyzing static value.
951         return PostorderDFSNode()
952             .AddDependency(root->operand_ids(0),
953                            PostorderDFSNodeType::kConstantValue, context)
954             .AddVisit(
955                 [](Literal operand) -> StatusOr<Literal> { return operand; });
956       } else if (root->custom_call_target() == "Sharding") {
957         return PostorderDFSNode()
958             .AddDependency(root->operand_ids(0),
959                            PostorderDFSNodeType::kConstantValue, context)
960             .AddVisit([](Literal operand) { return operand; });
961       } else {
962         return PostorderDFSNode().AddVisit(
963             [root, context](absl::Span<Literal>) {
964               // The value is dynamic. We return a garbage literal here, which
965               // will be masked out later.
966               return CreateGarbageLiteral(ShapeUtil::GetSubshape(
967                   Shape(root->shape()), context.shape_index));
968             });
969       }
970     }
971     case HloOpcode::kSort: {
972       PostorderDFSNode result;
973       InferenceContext dep_context = context;
974       dep_context.shape_index = {};
975       for (auto operand_id : root->operand_ids()) {
976         result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
977                              dep_context);
978       }
979       const HloComputationProto* computation_proto =
980           handle_to_computation(root->called_computation_ids(0));
981       return result.AddVisit(
982           [root, context, computation_proto,
983            this](absl::Span<Literal> operands) -> StatusOr<Literal> {
984             TF_ASSIGN_OR_RETURN(
985                 auto computation,
986                 HloComputation::CreateFromProto(*computation_proto, {}));
987             return HloProtoEvaluator(evaluator, *root)
988                 .WithOperands(operands)
989                 .WithComputation(std::move(computation))
990                 .WithSubshape(context.shape_index)
991                 .Evaluate();
992           });
993     }
994     default:
995       return AnalyzeConstantValueFallback(
996           handle, PostorderDFSNodeType::kConstantValue, context);
997   }
998 }
999 
1000 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeIsDynamic(
1001     int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
1002   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
1003                       handle_to_instruction(handle));
1004   // Invariant check.
1005   TF_RET_CHECK(root);
1006   VLOG(1) << "Analyzing IsDynamic on " << root->DebugString();
1007   Shape subshape =
1008       ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
1009   if (IsInstructionOverLimit(root, context)) {
1010     return CreateAllDynamicResult(subshape, type);
1011   }
1012   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
1013   PostorderDFSNode result;
1014   for (auto operand_id : root->operand_ids()) {
1015     InferenceContext dep_context = context;
1016     dep_context.shape_index = {};
1017     result.AddDependency(operand_id, type, dep_context);
1018   }
1019   switch (opcode) {
1020     case HloOpcode::kGetDimensionSize: {
1021       int64_t dimension = root->dimensions(0);
1022       int64_t operand_handle = root->operand_ids(0);
1023       TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
1024                           handle_to_instruction(operand_handle));
1025       return PostorderDFSNode().AddVisit(
1026           [operand_proto, dimension, type]() -> StatusOr<Literal> {
1027             if (type == PostorderDFSNodeType::kBoundIsDynamic) {
1028               // The bound of a dynamic dimension is not dynamic.
1029               return LiteralUtil::CreateR0<bool>(false);
1030             }
1031             // The value of a dynamic dimension is dynamic.
1032             return LiteralUtil::CreateR0<bool>(
1033                 operand_proto->shape().is_dynamic_dimension(dimension));
1034           });
1035     }
1036     case HloOpcode::kSort: {
1037       auto dfs = PostorderDFSNode();
1038       InferenceContext dep_context = context;
1039       dep_context.shape_index = {};
1040 
1041       for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
1042         dfs.AddDependency(root->operand_ids(i), type, dep_context);
1043       }
1044 
1045       return dfs.AddVisit([root, context, type](absl::Span<Literal> operands)
1046                               -> StatusOr<Literal> {
1047         bool all_operands_values_static = true;
1048         for (int64_t i = 0; i < operands.size(); ++i) {
1049           all_operands_values_static &= operands[i].IsAll(0);
1050         }
1051         if (type == PostorderDFSNodeType::kValueIsDynamic) {
1052           // If even a single operand of a sort is dynamic, we
1053           // conservatively say all results are dynamic.
1054           return CreatePredLiteral(!all_operands_values_static,
1055                                    ShapeUtil::GetSubshape(Shape(root->shape()),
1056                                                           context.shape_index));
1057         }
1058         CHECK(type == PostorderDFSNodeType::kBoundIsDynamic);
1059         // The condition for bounds is more relaxed than for values. If we
1060         // know the bounds of each element [B0, B1, ... Bn], all results have
1061         // the same bound
1062         // [max(B0, B1, ...), max(B0, B1, ...), ...].
1063         if (!context.shape_index.empty()) {
1064           int64_t index = context.shape_index[0];
1065           bool all_values_static = operands[index].IsAll(0);
1066           return CreatePredLiteral(!all_values_static, operands[index].shape());
1067         }
1068 
1069         std::vector<Literal> results;
1070         results.reserve(operands.size());
1071         for (int64_t i = 0; i < operands.size(); ++i) {
1072           bool all_values_static = operands[i].IsAll(0);
1073           results.emplace_back(
1074               CreatePredLiteral(!all_values_static, operands[i].shape()));
1075         }
1076         if (!ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index)
1077                  .IsTuple()) {
1078           return std::move(results[0]);
1079         }
1080         return LiteralUtil::MakeTupleOwned(std::move(results));
1081       });
1082     }
1083     case HloOpcode::kSetDimensionSize:
1084       return result.AddVisit([root, type](absl::Span<Literal> operands) {
1085         bool any_dynamic_operand = absl::c_any_of(
1086             operands, [](Literal& operand) { return !operand.IsAll(0); });
1087         // If the values in a tensor `t` with a bound are [e0, e1, e2, ...], we
1088         // can say the max value of each position is [max(t), max(t), ...]. The
1089         // effective size of this tensor doesn't change the max value.
1090         return CreatePredLiteral(
1091             type == PostorderDFSNodeType::kValueIsDynamic &&
1092                 any_dynamic_operand,
1093             ShapeUtil::MakeStaticShape(Shape(root->shape())));
1094       });
1095     case HloOpcode::kDynamicSlice: {
1096       return result.AddVisit([root](absl::Span<Literal> operands) {
1097         // If any of the operands is dynamic, we say the output is dynamic.
1098         bool any_dynamic_operand = absl::c_any_of(
1099             operands, [](Literal& operand) { return !operand.IsAll(0); });
1100         return CreatePredLiteral(any_dynamic_operand, Shape(root->shape()));
1101       });
1102     }
1103     case HloOpcode::kAbs:
1104     case HloOpcode::kRoundNearestAfz:
1105     case HloOpcode::kRoundNearestEven:
1106     case HloOpcode::kBitcast:
1107     case HloOpcode::kCeil:
1108     case HloOpcode::kCollectivePermuteDone:
1109     case HloOpcode::kCos:
1110     case HloOpcode::kClz:
1111     case HloOpcode::kExp:
1112     case HloOpcode::kExpm1:
1113     case HloOpcode::kFloor:
1114     case HloOpcode::kImag:
1115     case HloOpcode::kIsFinite:
1116     case HloOpcode::kLog:
1117     case HloOpcode::kLog1p:
1118     case HloOpcode::kNot:
1119     case HloOpcode::kNegate:
1120     case HloOpcode::kPopulationCount:
1121     case HloOpcode::kReal:
1122     case HloOpcode::kRsqrt:
1123     case HloOpcode::kLogistic:
1124     case HloOpcode::kSign:
1125     case HloOpcode::kSin:
1126     case HloOpcode::kConvert:
1127     case HloOpcode::kSqrt:
1128     case HloOpcode::kCbrt:
1129     case HloOpcode::kTanh: {
1130       // Forward the operand; these ops don't change whether a value is dynamic.
1131       return result.AddVisit([](Literal operand) { return operand; });
1132     }
1133     case HloOpcode::kAdd:
1134     case HloOpcode::kAtan2:
1135     case HloOpcode::kDivide:
1136     case HloOpcode::kComplex:
1137     case HloOpcode::kMaximum:
1138     case HloOpcode::kMinimum:
1139     case HloOpcode::kMultiply:
1140     case HloOpcode::kPower:
1141     case HloOpcode::kRemainder:
1142     case HloOpcode::kSubtract:
1143     case HloOpcode::kCompare:
1144     case HloOpcode::kAnd:
1145     case HloOpcode::kOr:
1146     case HloOpcode::kXor:
1147     case HloOpcode::kShiftLeft:
1148     case HloOpcode::kShiftRightArithmetic:
1149     case HloOpcode::kShiftRightLogical: {
1150       return result.AddVisit([root, this](absl::Span<Literal> operands) {
1151         return HloProtoEvaluator(evaluator, *root)
1152             .WithOperands(operands)
1153             .WithPrimitiveType(PRED)
1154             .WithOpCode(HloOpcode::kOr)
1155             .Evaluate();
1156       });
1157     }
1158     case HloOpcode::kTuple:
1159     case HloOpcode::kTranspose:
1160     case HloOpcode::kSlice:
1161     case HloOpcode::kBroadcast:
1162     case HloOpcode::kReverse:
1163     case HloOpcode::kConcatenate:
1164     case HloOpcode::kReshape:
1165     case HloOpcode::kPad: {
1166       if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
1167         // There could be many operands of a tuple, but only one that we are
1168         // interested in, represented by `tuple_operand_index`.
1169         int64_t tuple_operand_index = context.shape_index.front();
1170         InferenceContext tuple_operand_context = context;
1171         tuple_operand_context.shape_index.pop_front();
1172         return PostorderDFSNode()
1173             .AddDependency(root->operand_ids(tuple_operand_index), type,
1174                            tuple_operand_context)
1175             .AddVisit([](Literal operand) { return operand; });
1176       }
1177       return result.AddVisit([root, this](absl::Span<Literal> operands) {
1178         return HloProtoEvaluator(evaluator, *root)
1179             .WithOperands(operands)
1180             .WithPrimitiveType(PRED)
1181             .Evaluate();
1182       });
1183     }
1184     case HloOpcode::kCall: {
1185       auto node = PostorderDFSNode();
1186       auto* call_proto = root;
1187 
1188       if (call_proto->operand_ids_size() != 1) {
1189         // Only support single operand forwarding.
1190         return CreateAllDynamicResult(subshape, type);
1191       }
1192       int64_t call_root =
1193           handle_to_computation(call_proto->called_computation_ids(0))
1194               ->root_id();
1195       InferenceContext branch_context = context;
1196       branch_context.caller_operand_handles.push_back(
1197           call_proto->operand_ids(0));
1198       node.AddDependency(call_root, PostorderDFSNodeType::kValueIsDynamic,
1199                          branch_context, "callee's root instruction");
1200       return node.AddVisit([context](Literal operand) -> StatusOr<Literal> {
1201         // Forward result of callee's root to caller.
1202         return operand;
1203       });
1204     }
1205     case HloOpcode::kConditional: {
1206       auto node = PostorderDFSNode();
1207       auto* conditional_proto = root;
1208       InferenceContext predicate_context = context;
1209       predicate_context.shape_index = {};
1210       // Add dependencies to analyze the predicate of the conditional.
1211       node.AddDependency(conditional_proto->operand_ids(0),
1212                          PostorderDFSNodeType::kConstantValue,
1213                          predicate_context)
1214           .AddDependency(conditional_proto->operand_ids(0),
1215                          PostorderDFSNodeType::kValueIsDynamic,
1216                          predicate_context);
1217       const int64_t branch_size =
1218           conditional_proto->called_computation_ids_size();
1219       for (int64_t i = 0; i < branch_size; ++i) {
1220         int64_t branch_root =
1221             handle_to_computation(conditional_proto->called_computation_ids(i))
1222                 ->root_id();
1223         InferenceContext branch_context = context;
1224         branch_context.caller_operand_handles.push_back(
1225             conditional_proto->operand_ids(i + 1));
1226         node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
1227                            branch_context,
1228                            absl::StrFormat("branch %lld's value", i))
1229             .AddDependency(branch_root, PostorderDFSNodeType::kValueIsDynamic,
1230                            branch_context,
1231                            absl::StrFormat("branch %lld's dynamism", i));
1232       }
1233       // The predicate uses 2 dependencies:
1234       // 0: Predicate value.
1235       // 1: Predicate is dynamic.
1236       // Each branch i has 2 dependencies:
1237       // 2 + 2*i: Branch result value.
1238       // 2 + 2*i + 1: Branch result is dynamic.
1239       return node.AddVisit([root, branch_size,
1240                             context](absl::Span<Literal> operands)
1241                                -> StatusOr<Literal> {
1242         int64_t pred_is_dynamic = operands[1].Get<bool>({});
1243         auto result = CreatePredLiteral(
1244             true,
1245             ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1246         if (pred_is_dynamic) {
1247           VLOG(1) << "predict is dynamic value" << result.ToString();
1248           // If predicate is dynamic, the result is only static if all
1249           // branches are static and return the same value.
1250           result.MutableEachCell<bool>(
1251               [&](absl::Span<const int64_t> indices, bool value) {
1252                 std::string branch_value = operands[2].GetAsString(indices, {});
1253                 for (int64_t i = 0; i < branch_size; ++i) {
1254                   const int64_t branch_value_index = 2 + 2 * i;
1255                   const int64_t branch_dynamism_index = 2 + 2 * i + 1;
1256                   auto branch_is_dynamic =
1257                       operands[branch_dynamism_index].Get<bool>(indices);
1258                   if (branch_is_dynamic) {
1259                     return true;
1260                   }
1261 
1262                   if (branch_value !=
1263                       operands[branch_value_index].GetAsString(indices, {})) {
1264                     return true;
1265                   }
1266                 }
1267                 // Value of the branch is static.
1268                 return false;
1269               });
1270           return result;
1271         } else {
1272           VLOG(1) << "predict is constant value";
1273           // If predicate is static, return true if given branch result
1274           // value is dynamic.
1275           int64_t branch_index = 0;
1276           if (operands[0].shape().element_type() == PRED) {
1277             if (operands[0].Get<bool>({})) {
1278               branch_index = 0;
1279             } else {
1280               branch_index = 1;
1281             }
1282           } else {
1283             branch_index = operands[0].GetIntegralAsS64({}).value();
1284           }
1285           const int64_t branch_dynamism_index = 2 + 2 * branch_index + 1;
1286           return std::move(operands[branch_dynamism_index]);
1287         }
1288       });
1289     }
1290     case HloOpcode::kGetTupleElement: {
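      // Prepend the tuple index to the shape index so that the operand
      // analysis targets the selected sub-literal.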
1291       int64_t operand_handle = root->operand_ids(0);
1292       PostorderDFSNode result;
1293       context.shape_index.push_front(root->tuple_index());
1294       return PostorderDFSNode()
1295           .AddDependency(operand_handle, type, context)
1296           .AddVisit([](Literal operand) { return operand; });
1297     }
1298 
1299     case HloOpcode::kReduce: {
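      // The dynamism of a reduce is computed by re-running the reduce with its
      // reduction computation replaced by a logical OR over the operands'
      // dynamism literals: any dynamic input makes the reduced value dynamic.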
1300       return result.AddVisit(
1301           [root, context, this](absl::Span<Literal> operands) {
1302             Shape root_shape = Shape(root->shape());
1303             Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
1304             std::unique_ptr<HloComputation> reduce_or;
1305             if (root_shape.IsTuple()) {
1306               // Variadic reduce.
1307               HloComputation::Builder b("reduce_or");
1308               // Assuming all operands interact with each other. This could be
1309               // overly conservative.  If needed, a dataflow analysis could be
1310               // performed in the future.
1311               //
1312               // The value starts with `false` (static) and will be `or`ed with
1313               // all operands' dynamism.
1314               auto accum = b.AddInstruction(HloInstruction::CreateConstant(
1315                   LiteralUtil::CreateR0<bool>(false)));
1316 
1317               for (int i = 0; i < root_shape.tuple_shapes_size(); ++i) {
1318                 auto lhs = b.AddInstruction(
1319                     HloInstruction::CreateParameter(i, scalar_shape, "lhs"));
1320                 auto rhs = b.AddInstruction(HloInstruction::CreateParameter(
1321                     i + root_shape.tuple_shapes_size(), scalar_shape, "rhs"));
1322                 accum = b.AddInstruction(HloInstruction::CreateBinary(
1323                     scalar_shape, HloOpcode::kOr, accum, lhs));
1324                 accum = b.AddInstruction(HloInstruction::CreateBinary(
1325                     scalar_shape, HloOpcode::kOr, accum, rhs));
1326               }
1327               // `Broadcast` the accumulated dynamism to every position of
1327               // the result tuple.
1328               std::vector<HloInstruction*> results(
1329                   root_shape.tuple_shapes_size(), accum);
1330               b.AddInstruction(HloInstruction::CreateTuple(results));
1331               reduce_or = b.Build();
1332             } else {
1333               HloComputation::Builder b("reduce_or");
1334               auto lhs = b.AddInstruction(
1335                   HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1336               auto rhs = b.AddInstruction(
1337                   HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1338               b.AddInstruction(HloInstruction::CreateBinary(
1339                   scalar_shape, HloOpcode::kOr, lhs, rhs));
1340               reduce_or = b.Build();
1341             }
1342 
1343             return HloProtoEvaluator(evaluator, *root)
1344                 .WithOperands(operands)
1345                 .WithPrimitiveType(PRED)
1346                 .WithComputation(std::move(reduce_or))
1347                 // Reduce can produce a tuple shape; only fetch the subshape we need.
1348                 .WithSubshape(context.shape_index)
1349                 .Evaluate();
1350           });
1351     }
1352     case HloOpcode::kConstant:
1353     case HloOpcode::kIota: {
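      // Constants and iota values are always fully static, so their dynamism
      // literal is all-false.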
1354       return result.AddVisit(
1355           [root]() { return CreatePredLiteral(false, Shape(root->shape())); });
1356     }
1357     case HloOpcode::kParameter: {
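      // A parameter of a called computation resolves to the caller's operand
      // (recorded in caller_operand_handles); a top-level parameter has no
      // known value and is conservatively treated as fully dynamic.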
1358       if (opcode == HloOpcode::kParameter &&
1359           !context.caller_operand_handles.empty()) {
1360         int64_t caller_operand = context.caller_operand_handles.back();
1361         context.caller_operand_handles.pop_back();
1362         return result.AddDependency(caller_operand, type, context)
1363             .AddVisit([](Literal literal) { return literal; });
1364       }
1365       return result.AddVisit([root, context]() {
1366         return CreatePredLiteral(
1367             true,
1368             ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1369       });
1370     }
1371     case HloOpcode::kSelect: {
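      // A select element is static only when the selector element is static
      // and the chosen operand element is static; otherwise it is
      // conservatively dynamic.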
1372       return PostorderDFSNode()
1373           .AddDependency(root->operand_ids(0),
1374                          PostorderDFSNodeType::kConstantValue, context)
1375           .AddDependency(root->operand_ids(0),
1376                          PostorderDFSNodeType::kValueIsDynamic, context)
1377           // lhs dependency.
1378           .AddDependency(root->operand_ids(1), type, context)
1379           // rhs dependency.
1380           .AddDependency(root->operand_ids(2), type, context)
1381           .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
1382             OptionalLiteral optional_selector_literal(std::move(operands[0]),
1383                                                       std::move(operands[1]));
1384             Literal lhs = std::move(operands[2]);
1385             Literal rhs = std::move(operands[3]);
1386             auto result = CreatePredLiteral(true, Shape(root->shape()));
1387             result.MutableEachCell<bool>(
1388                 [&](absl::Span<const int64_t> indices, bool value) {
1389                   std::optional<bool> optional_selector =
1390                       optional_selector_literal.Get<bool>(indices);
1391 
1392                   bool lhs_value = lhs.Get<bool>(indices);
1393                   bool rhs_value = rhs.Get<bool>(indices);
1394                   if (optional_selector.has_value()) {
1395                     // Manually evaluate the selection without using Evaluator.
1396                     if (*optional_selector) {
1397                       return lhs_value;
1398                     } else {
1399                       return rhs_value;
1400                     }
1401                   } else {
1402                     // Conservatively assume value is dynamic if selector is
1403                     // dynamic.
1404                     return true;
1405                   }
1406                 });
1407             return result;
1408           });
1409     }
1410     case HloOpcode::kGather: {
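      // If the gather indices are fully static, evaluate the gather over the
      // operand's dynamism literal; otherwise conservatively mark the whole
      // result as dynamic.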
1411       return PostorderDFSNode()
1412           .AddDependency(root->operand_ids(0), type, context)
1413           .AddDependency(root->operand_ids(1),
1414                          PostorderDFSNodeType::kConstantValue, context)
1415           .AddDependency(root->operand_ids(1),
1416                          PostorderDFSNodeType::kValueIsDynamic, context)
1417           .AddVisit(
1418               [root, this](absl::Span<Literal> operands) -> StatusOr<Literal> {
1419                 OptionalLiteral optional_selector_literal(
1420                     std::move(operands[1]), std::move(operands[2]));
1421 
1422                 if (!optional_selector_literal.AllValid()) {
1423                   // Conservatively assume results are dynamic.
1424                   return CreatePredLiteral(true, Shape(root->shape()));
1425                 }
1426                 std::vector<Literal> new_operands;
1427                 new_operands.emplace_back(std::move(operands[0]));
1428                 new_operands.emplace_back(
1429                     optional_selector_literal.GetValue()->Clone());
1430 
1431                 return HloProtoEvaluator(evaluator, *root)
1432                     .WithOperands(absl::MakeSpan(new_operands))
1433                     .WithPrimitiveType(PRED)
1434                     .Evaluate();
1435               });
1436     }
1437     case HloOpcode::kCustomCall: {
1438       if (root->custom_call_target() == "SetBound") {
1439         return PostorderDFSNode().AddVisit([type, root]() -> StatusOr<Literal> {
1440           if (type == PostorderDFSNodeType::kBoundIsDynamic) {
1441             return CreatePredLiteral(false, Shape(root->shape()));
1442           } else {
1443             if (root->literal().shape().element_type() == TUPLE) {
1444               // First literal of SetBound contains bounds, second literal
1445               // contains dynamism indicators.
1446               return Literal::CreateFromProto(
1447                   root->literal().tuple_literals(1));
1448             } else if (type == PostorderDFSNodeType::kValueIsDynamic) {
1449               return CreatePredLiteral(true, Shape(root->shape()));
1450             } else {
1451               return Literal::CreateFromProto(root->literal());
1452             }
1453           }
1454         });
1455       } else if (root->custom_call_target() == "Sharding") {
1456         return result.AddVisit([](Literal operand) { return operand; });
1457       } else {
1458         return InvalidArgument(
1459             "Dynamic inferencing on custom call %s is not supported",
1460             root->DebugString());
1461       }
1462 
1463       break;
1464     }
1465 
1466     case HloOpcode::kRecv:
1467     case HloOpcode::kRecvDone:
1468     case HloOpcode::kSend:
1469     case HloOpcode::kSendDone:
1470     case HloOpcode::kWhile: {
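      // These ops are not analyzed; conservatively report the requested
      // subshape as fully dynamic.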
1471       return PostorderDFSNode().AddVisit([root,
1472                                           context]() -> StatusOr<Literal> {
1473         return CreatePredLiteral(
1474             true,
1475             ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1476       });
1477       break;
1478     }
1479     default:
1480       return PostorderDFSNode().AddVisit([root,
1481                                           context]() -> StatusOr<Literal> {
1482         return CreatePredLiteral(
1483             true,
1484             ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1485       });
1486   }
1487 }
1488 
1489 StatusOr<Literal> PostorderDFSVisitor::PostOrderDFSVisit(
1490     int64_t handle, PostorderDFSNodeType type) {
1491   enum VisitState {
1492     kUnvisited = 0,
1493     kVisiting,
1494     kVisited,
1495   };
1496 
1497   int64_t unique_id = 0;
1498   struct WorkItem {
1499     explicit WorkItem(int64_t handle, InferenceContext context,
1500                       PostorderDFSNodeType type, VisitState state, int64_t id)
1501         : handle(handle),
1502           context(std::move(context)),
1503           type(type),
1504           state(state),
1505           id(id) {}
1506     int64_t handle;  // Handle of the node in the graph.
1507     InferenceContext context;
1508     PostorderDFSNodeType type;
1509     VisitState state;
1510     Visit visit;  // The handler to call once the dependencies are resolved into
1511                   // literal form.
1512     int64_t id;   // Unique id in the work queue, starting from 0.
1513     std::vector<CacheKey> dependencies;
1514 
1515     CacheKey GetCacheKey() { return CacheKey(handle, context, type); }
1516   };
1517 
1518   std::vector<WorkItem> stack;
1519   WorkItem root(handle, InferenceContext({}, {}), type, kUnvisited,
1520                 unique_id++);
1521   stack.push_back(root);
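  // Iterative postorder DFS: a work item is pushed as kUnvisited, switched to
  // kVisiting once its dependencies have been enqueued, and finally visited
  // (its `visit` function invoked on the dependencies' literals) when it
  // reaches the top of the stack again.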
1522   while (!stack.empty()) {
1523     WorkItem& item = stack.back();
1524     VLOG(1) << "stack top shape index: " << item.context.shape_index.ToString();
1525     if (VLOG_IS_ON(1)) {
1526       TF_RETURN_IF_ERROR(handle_to_instruction(item.handle).status());
1527       VLOG(1) << "stack top "
1528               << handle_to_instruction(item.handle).ValueOrDie()->DebugString();
1529     }
1530     if (item.state == kVisiting) {
1531       VLOG(1) << "visiting";
1532       // The dependencies are ready, visit the node itself.
1533 
1534       // Gather dependencies and transform them into literals.
1535       std::vector<Literal> literals;
1536       literals.reserve(item.dependencies.size());
1537       for (CacheKey& dep_key : item.dependencies) {
1538         TF_RET_CHECK(evaluated.contains(dep_key));
1539         literals.emplace_back(evaluated.at(dep_key).Clone());
1540       }
1541       VLOG(1) << "Start visiting with dependency type: "
1542               << PostorderDFSNodeTypeToString(item.type);
1543       TF_ASSIGN_OR_RETURN(auto literal, item.visit(absl::MakeSpan(literals)));
1544       VLOG(1) << "End visiting: " << literal.ToString();
1545       evaluated[item.GetCacheKey()] = std::move(literal);
1546       stack.pop_back();
1547       continue;
1548     }
1549     // This is the first time we see this node; we want to gather its
1550     // dependencies.
1551     VLOG(1) << "unvisited";
1552     if (evaluated.contains(item.GetCacheKey())) {
1553       stack.pop_back();
1554       continue;
1555     }
1556     item.state = kVisiting;
1557     PostorderDFSNode node;
1558     switch (item.type) {
1559       case PostorderDFSNodeType::kConstantValue: {
1560         VLOG(1) << "constant value";
1561         TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle, item.context));
1562         break;
1563       }
1564       case PostorderDFSNodeType::kConstantLowerBound: {
1565         VLOG(1) << "constant lower bound";
1566         TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle, item.context));
1567         break;
1568       }
1569       case PostorderDFSNodeType::kConstantUpperBound: {
1570         VLOG(1) << "constant upper bound";
1571         TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle, item.context));
1572         break;
1573       }
1574       case PostorderDFSNodeType::kBoundIsDynamic:
1575       case PostorderDFSNodeType::kValueIsDynamic: {
1576         VLOG(1) << "value is dynamic";
1577         TF_ASSIGN_OR_RETURN(
1578             node, AnalyzeIsDynamic(item.handle, item.type, item.context));
1579         break;
1580       }
1581     }
1582     // Store the visit function which is needed when its dependencies are
1583     // resolved.
1584     item.visit = node.visit;
1585 
1586     const int64_t current_item_id = stack.size() - 1;
1587     // Enqueue dependencies into the stack. `item` shouldn't be accessed after
1588     // this point.
1589     for (const PostorderDFSDep& dep : node.dependencies) {
1590       TF_ASSIGN_OR_RETURN(auto dependency_inst,
1591                           handle_to_instruction(dep.handle));
1592       VLOG(1) << "dependency " << dep.annotation
1593               << "::" << dependency_inst->DebugString() << "index"
1594               << dep.context.shape_index << " stack size:" << stack.size();
1595       stack.emplace_back(dep.handle, dep.context, dep.type, kUnvisited,
1596                          unique_id++);
1597       stack[current_item_id].dependencies.push_back(stack.back().GetCacheKey());
1598     }
1599   }
1600   VLOG(1) << "done" << evaluated[root.GetCacheKey()].ToString();
1601   return evaluated[root.GetCacheKey()].Clone();
1602 }
1603 
1604 StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
1605   PostorderDFSVisitor visitor(
1606       evaluator_,
1607       [&](int64_t handle) {
1608         return builder_->LookUpInstructionByHandle(handle);
1609       },
1610       [&](int64_t handle) { return &(builder_->embedded_[handle]); });
1611 
1612   auto result = visitor.PostOrderDFSVisit(
1613       op.handle(), PostorderDFSNodeType::kValueIsDynamic);
1614   return result;
1615 }
1616 
1617 StatusOr<std::optional<int64_t>> ValueInference::CseOpHandle(int64_t handle) {
1618   TF_ASSIGN_OR_RETURN(auto inst, builder_->LookUpInstructionByHandle(handle));
1619   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode()));
1620   // For now, only handle kGetDimensionSize as that's the most duplicated one.
1621   if (opcode != HloOpcode::kGetDimensionSize) {
1622     return {std::nullopt};
1623   }
1624   int64_t hash = absl::HashOf(inst->operand_ids(0), inst->dimensions(0));
1625   auto lookup = cse_map_.find(hash);
1626   if (lookup == cse_map_.end()) {
1627     cse_map_[hash] = handle;
1628     return {std::nullopt};
1629   }
1630   TF_ASSIGN_OR_RETURN(auto equivalent_op,
1631                       builder_->LookUpInstructionByHandle(lookup->second));
1632   // Check that the op is indeed equivalent, to guard against hash
1633   // collisions, which are relatively easy to hit with a 64-bit hash.
1634   if (equivalent_op->opcode() != inst->opcode() ||
1635       equivalent_op->operand_ids(0) != inst->operand_ids(0) ||
1636       equivalent_op->dimensions(0) != inst->dimensions(0)) {
1637     // Hash collision, don't CSE.
1638     return {std::nullopt};
1639   }
1640   int64_t cse = lookup->second;
1641   if (handle != cse) {
1642     // Found an equivalent handle that is different from the input handle.
1643     return {cse};
1644   }
1645   return {std::nullopt};
1646 }
1647 
1648 StatusOr<Literal> ValueInference::SimplifyOp(int64_t handle) {
1649   TF_ASSIGN_OR_RETURN(auto cse_handle, CseOpHandle(handle));
1650   if (cse_handle) {
1651     // Use the CSE'd handle instead.
1652     return SimplifyOp(*cse_handle);
1653   }
1654   TF_ASSIGN_OR_RETURN(auto* inst, builder_->LookUpInstructionByHandle(handle));
1655   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode()));
1656   std::vector<Literal> operands;
1657   auto output_shape = Shape(inst->shape());
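  // SimplifyOp returns an S64 literal that holds, for each element position,
  // the handle of an op computing that element, or -1 when the element can't
  // be simplified to a known handle.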
1658   switch (opcode) {
1659     case HloOpcode::kSlice:
1660     case HloOpcode::kConcatenate:
1661     case HloOpcode::kReshape:
1662     case HloOpcode::kBroadcast: {
1663       for (auto operand_id : inst->operand_ids()) {
1664         TF_ASSIGN_OR_RETURN(auto literal, SimplifyOp(operand_id));
1665         operands.emplace_back(std::move(literal));
1666       }
1667       // We put handles into the tensor and evaluate the result into a literal.
1668       // The resulting literal also contains a handle for each element position.
1669       return HloProtoEvaluator(evaluator_, *inst)
1670           .WithOperands(absl::MakeSpan(operands))
1671           .WithPrimitiveType(S64)
1672           .Evaluate();
1673     }
1674     case HloOpcode::kConvert: {
1675       // Only an identity kConvert can be optimized away.
1676       auto operand = builder_->LookUpInstructionByHandle(inst->operand_ids(0))
1677                          .ValueOrDie();
1678       if (Shape::Equal()(output_shape, Shape(operand->shape()))) {
1679         // Forward operand handle as result.
1680         return SimplifyOp(inst->operand_ids(0));
1681       } else {
1682         return CreateS64Literal(-1, output_shape);
1683       }
1684     }
1685     case HloOpcode::kAdd: {
1686       // a + (b - a) => b
1687       // a + b + (c - a) => b + c
1688       if (output_shape.rank() == 0) {
1689         TF_ASSIGN_OR_RETURN(auto lhs, SimplifyOp(inst->operand_ids(0)));
1690         TF_ASSIGN_OR_RETURN(auto rhs, SimplifyOp(inst->operand_ids(1)));
1691         int64_t lhs_handle = lhs.Get<int64_t>({});
1692         int64_t rhs_handle = rhs.Get<int64_t>({});
1693         if (lhs_handle == -1 || rhs_handle == -1) {
1694           return CreateS64Literal(-1, output_shape);
1695         }
1696         // A recursive lambda needs an explicit signature.
1697         std::function<std::optional<int64_t>(int64_t, int64_t)>
1698             can_be_optimized;
1699         can_be_optimized = [this, &can_be_optimized](
1700                                int64_t lhs,
1701                                int64_t rhs) -> std::optional<int64_t> {
1702           auto rhs_inst = builder_->LookUpInstructionByHandle(rhs).ValueOrDie();
1703           HloOpcode rhs_opcode =
1704               StringToHloOpcode(rhs_inst->opcode()).ValueOrDie();
1705           if (rhs_opcode == HloOpcode::kSubtract) {
1706             auto sub_lhs_handle = SimplifyOp(rhs_inst->operand_ids(0))
1707                                       .ValueOrDie()
1708                                       .Get<int64_t>({});
1709             auto sub_rhs_handle = SimplifyOp(rhs_inst->operand_ids(1))
1710                                       .ValueOrDie()
1711                                       .Get<int64_t>({});
1712             if (sub_rhs_handle == lhs) {
1713               // lhs + (sub_lhs - sub_rhs) = sub_lhs if lhs == sub_rhs
1714               return sub_lhs_handle;
1715             }
1716           }
1717 
1718           // Check the case for a + b + (c - a) => b + c
1719           auto lhs_inst = builder_->LookUpInstructionByHandle(lhs).ValueOrDie();
1720           HloOpcode lhs_opcode =
1721               StringToHloOpcode(lhs_inst->opcode()).ValueOrDie();
1722           if (lhs_opcode == HloOpcode::kAdd) {
1723             auto add_lhs_handle = SimplifyOp(lhs_inst->operand_ids(0))
1724                                       .ValueOrDie()
1725                                       .Get<int64_t>({});
1726             auto add_rhs_handle = SimplifyOp(lhs_inst->operand_ids(1))
1727                                       .ValueOrDie()
1728                                       .Get<int64_t>({});
1729             if (auto optimized = can_be_optimized(add_lhs_handle, rhs)) {
1730               return Add(XlaOp(add_rhs_handle, builder_),
1731                          XlaOp(optimized.value(), builder_))
1732                   .handle();
1733             }
1734             if (auto optimized = can_be_optimized(add_rhs_handle, rhs)) {
1735               return Add(XlaOp(add_lhs_handle, builder_),
1736                          XlaOp(optimized.value(), builder_))
1737                   .handle();
1738             }
1739           }
1740           return std::nullopt;
1741         };
1742         if (auto optimized = can_be_optimized(lhs_handle, rhs_handle)) {
1743           return LiteralUtil::CreateR0<int64_t>(optimized.value());
1744         }
1745         // Try again with lhs and rhs swapped.
1746         if (auto optimized = can_be_optimized(rhs_handle, lhs_handle)) {
1747           return LiteralUtil::CreateR0<int64_t>(optimized.value());
1748         }
1749         // This sum can't be simplified away; return the sum of lhs and rhs.
1750         // Note that we can't just return the original sum, as its lhs and rhs
1751         // may themselves have been simplified into different handles.
1752         XlaOp new_sum =
1753             Add(XlaOp(lhs_handle, builder_), XlaOp(rhs_handle, builder_));
1754 
1755         return LiteralUtil::CreateR0<int64_t>(new_sum.handle());
1756       } else {
1757         return CreateS64Literal(-1, output_shape);
1758       }
1759     }
1760     default: {
1761       if (ShapeUtil::IsScalar(output_shape)) {
1762         return LiteralUtil::CreateR0<int64_t>(handle);
1763       } else {
1764         return CreateS64Literal(-1, output_shape);
1765       }
1766     }
1767   }
1768 }
1769 
1770 StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
1771     XlaOp op, ValueInferenceMode mode) {
1772   TF_RETURN_IF_ERROR(builder_->LookUpInstructionByHandle(op.handle()).status());
1773   PostorderDFSVisitor visitor(
1774       evaluator_,
1775       [&](int64_t handle) {
1776         return builder_->LookUpInstructionByHandle(handle);
1777       },
1778       [&](int64_t handle) { return &(builder_->embedded_[handle]); });
1779   TF_ASSIGN_OR_RETURN(Shape op_shape, builder_->GetShape(op));
1780   int64_t handle = op.handle();
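  // For scalar ops, first try to simplify the handle (CSE and algebraic
  // rewrites); if simplification yields a different handle, analyze that one
  // instead.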
1781   if (ShapeUtil::IsScalar(builder_->GetShape(op).ValueOrDie())) {
1782     TF_ASSIGN_OR_RETURN(auto result, SimplifyOp(handle));
1783     auto optimized_handle = result.Get<int64_t>({});
1784     if (optimized_handle != -1) {
1785       handle = optimized_handle;
1786     }
1787   }
1788   switch (mode) {
1789     case ValueInferenceMode::kLowerBound: {
1790       TF_ASSIGN_OR_RETURN(Literal mask,
1791                           visitor.PostOrderDFSVisit(
1792                               handle, PostorderDFSNodeType::kBoundIsDynamic));
1793       if (mask.IsAll(1)) {
1794         // Everything is dynamic, no need to do constant inference.
1795         return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
1796       }
1797       TF_ASSIGN_OR_RETURN(
1798           Literal value,
1799           visitor.PostOrderDFSVisit(handle,
1800                                     PostorderDFSNodeType::kConstantLowerBound));
1801 
1802       return OptionalLiteral(std::move(value), std::move(mask));
1803     }
1804     case ValueInferenceMode::kUpperBound: {
1805       TF_ASSIGN_OR_RETURN(Literal mask,
1806                           visitor.PostOrderDFSVisit(
1807                               handle, PostorderDFSNodeType::kBoundIsDynamic));
1808       if (mask.IsAll(1)) {
1809         // Everything is dynamic, no need to do constant inference.
1810         return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
1811       }
1812       TF_ASSIGN_OR_RETURN(
1813           Literal value,
1814           visitor.PostOrderDFSVisit(handle,
1815                                     PostorderDFSNodeType::kConstantUpperBound));
1816 
1817       return OptionalLiteral(std::move(value), std::move(mask));
1818     }
1819     case ValueInferenceMode::kValue: {
1820       TF_ASSIGN_OR_RETURN(Literal mask,
1821                           visitor.PostOrderDFSVisit(
1822                               handle, PostorderDFSNodeType::kValueIsDynamic));
1823       if (mask.IsAll(1)) {
1824         // Everything is dynamic, no need to do constant inference.
1825         return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
1826       }
1827       TF_ASSIGN_OR_RETURN(Literal value,
1828                           visitor.PostOrderDFSVisit(
1829                               handle, PostorderDFSNodeType::kConstantValue));
1830 
1831       return OptionalLiteral(std::move(value), std::move(mask));
1832     }
1833   }
1834 }
1835 
1836 }  // namespace xla
1837