/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/value_inference.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace {
Literal CreatePredLiteral(bool pred, const Shape& reference_shape) {
  if (reference_shape.IsTuple()) {
    std::vector<Literal> sub_literals;
    for (const Shape& shape : reference_shape.tuple_shapes()) {
      sub_literals.emplace_back(CreatePredLiteral(pred, shape));
    }
    return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
  }
  PrimitiveType element_type = reference_shape.element_type();
  if (element_type == TOKEN) {
    return LiteralUtil::CreateR0(pred);
  }
  Literal literal = LiteralUtil::CreateR0(pred);
  Literal literal_broadcast =
      literal.Broadcast(ShapeUtil::ChangeElementType(reference_shape, PRED), {})
          .ValueOrDie();
  return literal_broadcast;
}
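
// Illustrative sketch (comment only, not part of the build): for a
// hypothetical f32[2,3] reference shape, CreatePredLiteral broadcasts the
// scalar predicate to a pred[2,3] literal:
//
//   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
//   Literal mask = CreatePredLiteral(true, shape);
//   // mask.shape() is pred[2,3] and mask.IsAll(1) holds.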

Literal CreateS64Literal(int64_t value, const Shape& reference_shape) {
  if (reference_shape.IsTuple()) {
    std::vector<Literal> sub_literals;
    for (const Shape& shape : reference_shape.tuple_shapes()) {
      sub_literals.emplace_back(CreateS64Literal(value, shape));
    }
    return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
  }
  PrimitiveType element_type = reference_shape.element_type();
  if (element_type == TOKEN) {
    return LiteralUtil::CreateToken();
  }
  Literal literal = LiteralUtil::CreateR0<int64>(value);
  return literal
      .Broadcast(ShapeUtil::ChangeElementType(reference_shape, S64), {})
      .ValueOrDie();
}

// Create a literal with garbage data. The data inside is undefined and
// shouldn't be used in any meaningful computation.
Literal CreateGarbageLiteral(const Shape& reference_shape) {
  if (reference_shape.IsTuple()) {
    std::vector<Literal> sub_literals;
    for (const Shape& shape : reference_shape.tuple_shapes()) {
      sub_literals.emplace_back(CreateGarbageLiteral(shape));
    }
    return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
  }
  PrimitiveType element_type = reference_shape.element_type();
  if (element_type == TOKEN) {
    return LiteralUtil::CreateToken();
  }
  Literal literal = LiteralUtil::One(element_type);
  return literal.Broadcast(reference_shape, {}).ValueOrDie();
}

// HloProtoEvaluator evaluates an hlo proto and returns a literal. The user has
// to provide the operands as literals through WithOperands().
struct HloProtoEvaluator {
  explicit HloProtoEvaluator(HloInstructionProto inst)
      : inst(std::move(inst)),
        module("EmptyModuleForEvaluation", HloModuleConfig()) {}

  // WithComputation changes the called computation of the instruction being
  // evaluated.
  HloProtoEvaluator& WithComputation(
      std::unique_ptr<HloComputation> new_computation) {
    computation = new_computation.get();
    computation->ClearUniqueIdInternal();
    for (HloInstruction* inst : computation->instructions()) {
      inst->ClearUniqueIdInternal();
    }
    module.AddEmbeddedComputation(std::move(new_computation));
    return *this;
  }

  // WithPrimitiveType changes the primitive type of the instruction being
  // evaluated.
  HloProtoEvaluator& WithPrimitiveType(PrimitiveType new_primitive_type) {
    primitive_type = new_primitive_type;
    return *this;
  }

  // WithOpCode changes the opcode of the instruction being evaluated.
  HloProtoEvaluator& WithOpCode(HloOpcode new_opcode) {
    opcode = new_opcode;
    return *this;
  }

  // WithOperands changes the operands of the instruction being evaluated.
  HloProtoEvaluator& WithOperands(absl::Span<Literal> operands) {
    this->operands = operands;
    return *this;
  }

  // When WithSubshape is set, the result tuple will be decomposed and only
  // the sub-literal at that shape index will be returned.
  HloProtoEvaluator& WithSubshape(ShapeIndex shape_index) {
    this->shape_index = std::move(shape_index);
    return *this;
  }

  StatusOr<Literal> Evaluate() {
    // Evaluate the instruction by swapping its operands with constant
    // instructions holding the given literals.
    HloComputation::Builder builder("EmptyComputation");
    absl::flat_hash_map<int64, HloInstruction*> operand_map;
    for (int64_t i = 0; i < inst.operand_ids_size(); ++i) {
      int64_t operand_handle = inst.operand_ids(i);
      std::unique_ptr<HloInstruction> operand =
          HloInstruction::CreateConstant(operands[i].Clone());
      operand_map[operand_handle] = operand.get();
      builder.AddInstruction(std::move(operand));
    }

    if (primitive_type.has_value()) {
      *inst.mutable_shape() = ShapeUtil::ChangeElementType(
                                  Shape(inst.shape()), primitive_type.value())
                                  .ToProto();
    }
    if (opcode.has_value()) {
      *inst.mutable_opcode() = HloOpcodeString(opcode.value());
    }
    absl::flat_hash_map<int64, HloComputation*> computation_map;
    if (inst.called_computation_ids_size() != 0) {
      TF_RET_CHECK(inst.called_computation_ids_size() == 1 &&
                   computation != nullptr)
          << inst.DebugString();
      computation_map[inst.called_computation_ids(0)] = computation;
    }
    TF_ASSIGN_OR_RETURN(
        auto new_instruction,
        HloInstruction::CreateFromProto(inst, operand_map, computation_map));
    new_instruction->ClearUniqueIdInternal();
    builder.AddInstruction(std::move(new_instruction));
    auto computation = builder.Build();
    module.AddEntryComputation(std::move(computation));
    HloEvaluator evaluator;
    if (shape_index.empty()) {
      return evaluator.Evaluate(module.entry_computation()->root_instruction());
    } else {
      TF_ASSIGN_OR_RETURN(
          auto result,
          evaluator.Evaluate(module.entry_computation()->root_instruction()));
      return result.SubLiteral(this->shape_index);
    }
  }

  HloInstructionProto inst;

  HloModule module;
  absl::Span<Literal> operands;
  ShapeIndex shape_index = {};
  HloComputation* computation = nullptr;
  absl::optional<PrimitiveType> primitive_type = absl::nullopt;
  absl::optional<HloOpcode> opcode = absl::nullopt;
};
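
// Usage sketch (illustrative only; `multiply_proto` is a hypothetical
// HloInstructionProto for an s32 multiply whose operand values are known):
//
//   std::vector<Literal> operands;
//   operands.emplace_back(LiteralUtil::CreateR0<int32>(6));
//   operands.emplace_back(LiteralUtil::CreateR0<int32>(7));
//   StatusOr<Literal> product = HloProtoEvaluator(multiply_proto)
//                                   .WithOperands(absl::MakeSpan(operands))
//                                   .Evaluate();  // Literal containing 42.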

enum PostorderDFSNodeType {
  // This node is about figuring out the constant value.
  kConstantValue = 0,
  // This node is about figuring out the constant bound.
  kConstantUpperBound,
  kConstantLowerBound,
  // This node is about figuring out whether a value is dynamic.
  kValueIsDynamic,
  // This node is about figuring out whether a bound value is dynamic. It's
  // similar to kValueIsDynamic, but treats shape bounds as static values.
  kBoundIsDynamic,
};

std::string PostorderDFSNodeTypeToString(PostorderDFSNodeType type) {
  switch (type) {
    case kConstantValue:
      return "kConstantValue";
    case kConstantUpperBound:
      return "kConstantUpperBound";
    case kConstantLowerBound:
      return "kConstantLowerBound";
    case kValueIsDynamic:
      return "kValueIsDynamic";
    case kBoundIsDynamic:
      return "kBoundIsDynamic";
  }
}
227 
228 struct InferenceContext {
InferenceContextxla::__anon68edb0ec0111::InferenceContext229   explicit InferenceContext(ShapeIndex shape_index,
230                             std::vector<int64> caller_operand_handles)
231       : shape_index(std::move(shape_index)),
232         caller_operand_handles(std::move(caller_operand_handles)) {}
233   // `shape_index` represents the subshape that we care about in the inference.
234   // It is used to avoid meterializing the whole tuple when we only care about a
235   // sub tensor of it.
236   ShapeIndex shape_index;
237 
238   // caller_operand_handles is a stack that helps argument forwarding. The top
239   // of the stack represents the tensor to be forwarded to the
240   // parameter of the inner most function. E.g.,:
241   // inner_true_computation {
242   //   inner_p0 = param()
243   //   ...
244   // }
245   //
246   // true_computaion {
247   //   p0 = param()
248   //   conditional(pred, p0, inner_true_computation,
249   //                     ...)
250   // }
251   //
252   // main {
253   //   op = ..
254   //   conditional(pred, op, true_computation, ...)
255   // }
256   //
257   // In this case, when we analyze inner_true_computation, the
258   // `caller_operand_handlers` will be [op, p0] -- p0 is what should be
259   // forwarded to inner_p0 and op is what should be forwarded to p0. similarly,
260   // when we analyze true_computation, the `caller_operand_handlers` will be
261   // [op].
262   std::vector<int64> caller_operand_handles;
263 };

// Each node in the postorder traversal tree may depend on traversing the
// values of the node's children.
struct PostorderDFSDep {
  explicit PostorderDFSDep(int64_t handle, PostorderDFSNodeType type,
                           InferenceContext context, std::string annotation)
      : handle(handle),
        type(type),
        context(std::move(context)),
        annotation(std::move(annotation)) {}
  int64 handle;
  PostorderDFSNodeType type;
  InferenceContext context;
  std::string annotation;
};

// This function represents the logic to visit a node once its dependencies
// (operands) are all resolved.
using Visit = std::function<StatusOr<Literal>(absl::Span<Literal>)>;
// Convenient specializations of the Visit function for different operand
// counts.
using Visit0D = std::function<StatusOr<Literal>()>;
using Visit1D = std::function<StatusOr<Literal>(Literal)>;
using Visit2D = std::function<StatusOr<Literal>(Literal, Literal)>;

// A postorder dfs node can be visited once its dependency requests are all
// fulfilled.
struct TF_MUST_USE_RESULT PostorderDFSNode {
  PostorderDFSNode& AddDependency(int64_t handle, PostorderDFSNodeType type,
                                  InferenceContext context,
                                  std::string annotation = "") {
    dependencies.emplace_back(handle, type, std::move(context),
                              std::move(annotation));
    return *this;
  }

  PostorderDFSNode& AddVisit(const Visit& visit) {
    this->visit = visit;
    return *this;
  }

  PostorderDFSNode& AddVisit(const Visit0D& visit) {
    this->visit = [visit](absl::Span<Literal> literals) { return visit(); };
    return *this;
  }

  PostorderDFSNode& AddVisit(const Visit1D& visit) {
    this->visit = [visit](absl::Span<Literal> literals) {
      return visit(std::move(literals[0]));
    };
    return *this;
  }

  PostorderDFSNode& AddVisit(const Visit2D& visit) {
    this->visit = [visit](absl::Span<Literal> literals) {
      return visit(std::move(literals[0]), std::move(literals[1]));
    };
    return *this;
  }
  std::vector<PostorderDFSDep> dependencies;
  Visit visit;
};
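
// Illustrative sketch: a node that depends on one operand's constant value and
// forwards the resolved literal unchanged once the dependency is fulfilled
// (`operand_handle` and `context` are hypothetical):
//
//   PostorderDFSNode node;
//   node.AddDependency(operand_handle, PostorderDFSNodeType::kConstantValue,
//                      context)
//       .AddVisit([](Literal operand) { return operand; });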

// Convert an integer handle to an HloInstructionProto or HloComputationProto.
using HandleToInstruction = std::function<const HloInstructionProto*(int64_t)>;
using HandleToComputation = std::function<const HloComputationProto*(int64_t)>;

struct PostorderDFSVisitor {
  PostorderDFSVisitor(HandleToInstruction handle_to_instruction,
                      HandleToComputation handle_to_computation)
      : handle_to_instruction(handle_to_instruction),
        handle_to_computation(handle_to_computation) {}

  StatusOr<PostorderDFSNode> AnalyzeUpperBound(int64_t handle,
                                               InferenceContext context);
  StatusOr<PostorderDFSNode> AnalyzeLowerBound(int64_t handle,
                                               InferenceContext context);
  StatusOr<PostorderDFSNode> AnalyzeIsDynamic(int64_t handle,
                                              PostorderDFSNodeType type,
                                              InferenceContext context);
  StatusOr<PostorderDFSNode> AnalyzeConstant(int64_t handle,
                                             InferenceContext context);
  StatusOr<PostorderDFSNode> AnalyzeConstantValueFallback(
      int64_t handle, PostorderDFSNodeType type, InferenceContext context);

  StatusOr<Literal> PostOrderDFSVisit(int64_t handle,
                                      PostorderDFSNodeType type);

  // Returns true if the value represented by `handle` is an integral type or
  // a floating-point type that was just converted from an integral type.
  // E.g.,:
  // 1 -> true
  // float(1) -> true
  // 1.1 -> false
  // 1.0 -> false -- We don't know the concrete value at runtime, except for
  // its type.
  bool IsValueEffectiveInteger(int64_t handle) {
    const HloInstructionProto* instr = handle_to_instruction(handle);
    if (primitive_util::IsIntegralType(instr->shape().element_type())) {
      return true;
    }
    // Also returns true if this is a convert that converts an integer to
    // float.
    HloOpcode opcode = StringToHloOpcode(instr->opcode()).ValueOrDie();
    if (opcode != HloOpcode::kConvert) {
      return false;
    }
    const HloInstructionProto* parent =
        handle_to_instruction(instr->operand_ids(0));
    if (primitive_util::IsIntegralType(parent->shape().element_type())) {
      return true;
    }
    return false;
  }

  struct CacheKey {
    CacheKey(int64 handle, InferenceContext context, PostorderDFSNodeType type)
        : handle(handle), context(context), type(type) {}
    int64 handle;
    InferenceContext context;
    PostorderDFSNodeType type;

    template <typename H>
    friend H AbslHashValue(H h, const CacheKey& key) {
      h = H::combine(std::move(h), key.handle);
      h = H::combine(std::move(h), key.context.shape_index.ToString());
      h = H::combine(std::move(h),
                     VectorString(key.context.caller_operand_handles));
      h = H::combine(std::move(h), key.type);
      return h;
    }

    friend bool operator==(const CacheKey& lhs, const CacheKey& rhs) {
      return lhs.handle == rhs.handle &&
             lhs.context.shape_index == rhs.context.shape_index &&
             lhs.context.caller_operand_handles ==
                 rhs.context.caller_operand_handles &&
             lhs.type == rhs.type;
    }
  };

  absl::flat_hash_map<CacheKey, Literal> evaluated;
  HandleToInstruction handle_to_instruction;
  HandleToComputation handle_to_computation;
};

}  // namespace

// Analyze a tensor's constant value, upper-bound value or lower-bound value.
StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstantValueFallback(
    int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
  const HloInstructionProto* root = handle_to_instruction(handle);
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
  PostorderDFSNode result;
  // By default, the dependencies of the current node are its operands.
  for (auto operand_id : root->operand_ids()) {
    InferenceContext dep_context = context;
    dep_context.shape_index = {};
    result.AddDependency(operand_id, type, dep_context);
  }
  switch (opcode) {
      // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecvDone:
    case HloOpcode::kParameter: {
      if (opcode == HloOpcode::kParameter &&
          !context.caller_operand_handles.empty()) {
        int64_t caller_operand = context.caller_operand_handles.back();
        context.caller_operand_handles.pop_back();
        return result.AddDependency(caller_operand, type, context)
            .AddVisit([](Literal literal) { return literal; });
      }
      return PostorderDFSNode().AddVisit([root, context](absl::Span<Literal>) {
        // The value is dynamic. We return a garbage literal here, which
        // will be masked out later.
        return CreateGarbageLiteral(
            ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
      });
    }
    // Subtract and Divide use lower-bound as second operand.
    case HloOpcode::kSubtract:
    case HloOpcode::kCos:
    case HloOpcode::kSin:
    case HloOpcode::kNegate:
    case HloOpcode::kAbs:
    case HloOpcode::kDivide:
    case HloOpcode::kGetDimensionSize: {
      return InvalidArgument(
          "AnalyzeConstantValueFallback can't handle opcode: %s",
          root->opcode());
    }

    case HloOpcode::kConditional: {
      auto node = PostorderDFSNode();
      auto* conditional_proto = root;
      InferenceContext predicate_context = context;
      predicate_context.shape_index = {};
      // Add dependencies to analyze the predicate of the conditional.
      node.AddDependency(conditional_proto->operand_ids(0),
                         PostorderDFSNodeType::kConstantValue,
                         predicate_context)
          .AddDependency(conditional_proto->operand_ids(0),
                         PostorderDFSNodeType::kValueIsDynamic,
                         predicate_context);
      const int64_t branch_size =
          conditional_proto->called_computation_ids_size();
      for (int64_t i = 0; i < branch_size; ++i) {
        int64_t branch_root =
            handle_to_computation(conditional_proto->called_computation_ids(i))
                ->root_id();
        InferenceContext branch_context = context;
        branch_context.caller_operand_handles.push_back(
            conditional_proto->operand_ids(i + 1));
        node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
                           branch_context);
      }
      return node.AddVisit(
          [](absl::Span<Literal> operands) -> StatusOr<Literal> {
            int64_t pred_is_dynamic = operands[1].Get<bool>({});
            if (pred_is_dynamic) {
              // If the predicate is dynamic, return the value of the first
              // branch -- if all branches return the same value, this is the
              // value that we want; if not, the value will be masked anyway
              // so the value inside doesn't matter.
              return std::move(operands[2]);
            } else {
              // If the predicate is static, return the value of the given
              // branch.
              int64_t branch_index = 0;
              if (operands[0].shape().element_type() == PRED) {
                if (operands[0].Get<bool>({})) {
                  branch_index = 0;
                } else {
                  branch_index = 1;
                }
              } else {
                branch_index = operands[0].GetIntegralAsS64({}).value();
              }
              const int64_t branch_value_index = 2 + branch_index;
              return std::move(operands[branch_value_index]);
            }
          });
    }
    case HloOpcode::kGetTupleElement: {
      int64_t operand_handle = root->operand_ids(0);
      PostorderDFSNode result;
      context.shape_index.push_front(root->tuple_index());
      return PostorderDFSNode()
          .AddDependency(operand_handle, type, context)
          .AddVisit([](Literal operand) { return operand; });
    }
    case HloOpcode::kReduce:
    case HloOpcode::kScatter:
    case HloOpcode::kReduceWindow: {
      const HloComputationProto* computation_proto =
          handle_to_computation(root->called_computation_ids(0));
      return result.AddVisit(
          [root, computation_proto,
           context](absl::Span<Literal> operands) -> StatusOr<Literal> {
            TF_ASSIGN_OR_RETURN(
                auto computation,
                HloComputation::CreateFromProto(*computation_proto, {}));
            return HloProtoEvaluator(*root)
                .WithOperands(operands)
                .WithComputation(std::move(computation))
                .WithSubshape(context.shape_index)
                .Evaluate();
          });
    }
    default: {
      if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
        // There could be many operands of a tuple, but only one that we are
        // interested in, represented by `tuple_operand_index`.
        int64_t tuple_operand_index = context.shape_index.front();
        InferenceContext tuple_operand_context = context;
        tuple_operand_context.shape_index.pop_front();
        return PostorderDFSNode()
            .AddDependency(root->operand_ids(tuple_operand_index), type,
                           tuple_operand_context)
            .AddVisit([](Literal operand) { return operand; });
      }
      return result.AddVisit([root](absl::Span<Literal> operands) {
        return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
      });
    }
  }
}

StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeUpperBound(
    int64_t handle, InferenceContext context) {
  const HloInstructionProto* root = handle_to_instruction(handle);
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
  switch (opcode) {
    case HloOpcode::kGetDimensionSize: {
      int64_t dimension = root->dimensions(0);
      int64_t operand_handle = root->operand_ids(0);
      const HloInstructionProto* operand_proto =
          handle_to_instruction(operand_handle);
      return PostorderDFSNode().AddVisit(
          [operand_proto, dimension]() -> StatusOr<Literal> {
            return LiteralUtil::CreateR0<int32>(
                operand_proto->shape().dimensions(dimension));
          });
    }
    case HloOpcode::kAbs: {
      // upper-bound(abs(operand)) = max(abs(lower-bound(operand)),
      //                                 abs(upper-bound(operand)))
      return PostorderDFSNode()
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantLowerBound, context)
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantUpperBound, context)
          .AddVisit([](Literal lower_bound,
                       Literal upper_bound) -> StatusOr<Literal> {
            HloEvaluator evaluator;
            TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
                                evaluator.EvaluateElementwiseUnaryOp(
                                    HloOpcode::kAbs, lower_bound));
            TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
                                evaluator.EvaluateElementwiseUnaryOp(
                                    HloOpcode::kAbs, upper_bound));
            return evaluator.EvaluateElementwiseBinaryOp(
                HloOpcode::kMaximum, lower_bound_abs, upper_bound_abs);
          });
    }
    case HloOpcode::kSort: {
      auto dfs = PostorderDFSNode();
      InferenceContext dep_context = context;
      dep_context.shape_index = {};
      if (!context.shape_index.empty()) {
        // Lazy evaluation: Only need to evaluate a subelement in a
        // variadic-sort tensor.
        dfs.AddDependency(root->operand_ids(context.shape_index[0]),
                          PostorderDFSNodeType::kConstantUpperBound,
                          dep_context);
      } else {
        for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
          dfs.AddDependency(root->operand_ids(i),
                            PostorderDFSNodeType::kConstantUpperBound,
                            dep_context);
        }
      }

      return dfs.AddVisit(
          [root, context](absl::Span<Literal> operands) -> StatusOr<Literal> {
            std::vector<Literal> results;
            results.reserve(operands.size());
            // Conservatively set each element of the tensor to the max value.
            for (int64_t i = 0; i < operands.size(); ++i) {
              auto max = LiteralUtil::MaxElement(operands[i]);
              results.emplace_back(
                  max.Broadcast(operands[i].shape(), {}).ValueOrDie());
            }
            if (ShapeUtil::GetSubshape(Shape(root->shape()),
                                       context.shape_index)
                    .IsTuple()) {
              return LiteralUtil::MakeTupleOwned(std::move(results));
            } else {
              return std::move(results[0]);
            }
          });
    }
    case HloOpcode::kNegate: {
      // upper-bound(negate(operand)) = negate(lower-bound(operand))
      return PostorderDFSNode()
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantLowerBound, context)
          .AddVisit([](Literal lower_bound) -> StatusOr<Literal> {
            HloEvaluator evaluator;
            return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
                                                        lower_bound);
          });
    }
    case HloOpcode::kSubtract:
    case HloOpcode::kDivide: {
      // The lower bound is used for the second operand of subtract and divide.
      return PostorderDFSNode()
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantUpperBound, context)
          .AddDependency(root->operand_ids(1),
                         PostorderDFSNodeType::kConstantLowerBound, context)
          .AddVisit([root, opcode, this](
                        Literal upper_bound,
                        Literal lower_bound) -> StatusOr<Literal> {
            if (opcode == HloOpcode::kDivide &&
                this->IsValueEffectiveInteger(root->operand_ids(1))) {
              // Because in many cases the lower bound of a value is
              // integer 0, instead of throwing a divide-by-zero error
              // at compile time, we set the bound and defer the check
              // to runtime. In those cases we use the upper bound of
              // the first operand as a placeholder.
              HloEvaluator evaluator;
              auto zero = LiteralUtil::Zero(lower_bound.shape().element_type());
              zero = zero.Broadcast(lower_bound.shape(), {}).ValueOrDie();
              TF_ASSIGN_OR_RETURN(
                  auto lower_bound_is_zero,
                  evaluator.EvaluateElementwiseCompareOp(
                      ComparisonDirection::kEq, lower_bound, zero));

              auto one = LiteralUtil::One(lower_bound.shape().element_type());
              one = one.Broadcast(lower_bound.shape(), {}).ValueOrDie();
              TF_ASSIGN_OR_RETURN(
                  lower_bound, evaluator.EvaluateElementwiseTernaryOp(
                                   HloOpcode::kSelect, lower_bound_is_zero, one,
                                   lower_bound));
            }
            std::vector<Literal> new_operands;
            new_operands.emplace_back(std::move(upper_bound));
            new_operands.emplace_back(std::move(lower_bound));
            return HloProtoEvaluator(*root)
                .WithOperands(absl::MakeSpan(new_operands))
                .Evaluate();
          });
    }
    case HloOpcode::kCustomCall: {
      if (root->custom_call_target() == "SetBound") {
        return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
          if (root->literal().shape().element_type() == TUPLE) {
            // First literal of SetBound contains bounds, second literal
            // contains dynamism indicators.
            return Literal::CreateFromProto(root->literal().tuple_literals(0));
          } else {
            return Literal::CreateFromProto(root->literal());
          }
        });
      } else if (root->custom_call_target() == "Sharding") {
        return PostorderDFSNode()
            .AddDependency(root->operand_ids(0),
                           PostorderDFSNodeType::kConstantUpperBound, context)
            .AddVisit([](Literal operand) { return operand; });
      }
      return InvalidArgument(
          "Upper-bound inferencing on custom call %s is not supported",
          root->DebugString());
    }
    default:
      return AnalyzeConstantValueFallback(
          handle, PostorderDFSNodeType::kConstantUpperBound, context);
  }
}

StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeLowerBound(
    int64_t handle, InferenceContext context) {
  const HloInstructionProto* root = handle_to_instruction(handle);
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
  switch (opcode) {
    case HloOpcode::kGetDimensionSize: {
      int64_t dimension = root->dimensions(0);
      int64_t operand_handle = root->operand_ids(0);
      const HloInstructionProto* operand_proto =
          handle_to_instruction(operand_handle);
      return PostorderDFSNode().AddVisit(
          [dimension, operand_proto]() -> StatusOr<Literal> {
            if (operand_proto->shape().is_dynamic_dimension(dimension)) {
              return LiteralUtil::CreateR0<int32>(0);
            } else {
              return LiteralUtil::CreateR0<int32>(
                  operand_proto->shape().dimensions(dimension));
            }
          });
    }
    case HloOpcode::kAbs: {
      // lower-bound(abs(operand)) = min(abs(lower-bound(operand)),
      //                                 abs(upper-bound(operand)))
      return PostorderDFSNode()
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantLowerBound, context)
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantUpperBound, context)
          .AddVisit([](Literal lower_bound,
                       Literal upper_bound) -> StatusOr<Literal> {
            HloEvaluator evaluator;
            TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
                                evaluator.EvaluateElementwiseUnaryOp(
                                    HloOpcode::kAbs, lower_bound));
            TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
                                evaluator.EvaluateElementwiseUnaryOp(
                                    HloOpcode::kAbs, upper_bound));
            return evaluator.EvaluateElementwiseBinaryOp(
                HloOpcode::kMinimum, lower_bound_abs, upper_bound_abs);
          });
    }
    case HloOpcode::kNegate: {
      // lower-bound(negate(operand)) = negate(upper-bound(operand))
      return PostorderDFSNode()
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantUpperBound, context)
          .AddVisit([](Literal upper_bound) -> StatusOr<Literal> {
            HloEvaluator evaluator;
            return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
                                                        upper_bound);
          });
    }
    case HloOpcode::kSubtract:
    case HloOpcode::kDivide: {
      // Upper bound is used for second operand of subtract and divide.
      return PostorderDFSNode()
          .AddDependency(root->operand_ids(0),
                         PostorderDFSNodeType::kConstantLowerBound, context)
          .AddDependency(root->operand_ids(1),
                         PostorderDFSNodeType::kConstantUpperBound, context)
          .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
            return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
          });
    }
    default:
      return AnalyzeConstantValueFallback(
          handle, PostorderDFSNodeType::kConstantLowerBound, context);
  }
}

StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstant(
    int64_t handle, InferenceContext context) {
  const HloInstructionProto* root = handle_to_instruction(handle);
  HloOpcode opcode = StringToHloOpcode(root->opcode()).ValueOrDie();
  switch (opcode) {
    case HloOpcode::kGetDimensionSize: {
      int64_t dimension = root->dimensions(0);
      int64_t operand_handle = root->operand_ids(0);
      const HloInstructionProto* operand_proto =
          handle_to_instruction(operand_handle);
      return PostorderDFSNode().AddVisit(
          [operand_proto, dimension, root]() -> StatusOr<Literal> {
            if (operand_proto->shape().is_dynamic_dimension(dimension)) {
              // The value is dynamic; we return garbage data here and mask
              // it out later.
              return CreateGarbageLiteral(Shape(root->shape()));
            } else {
              return LiteralUtil::CreateR0<int32>(
                  operand_proto->shape().dimensions(dimension));
            }
          });
    }
    case HloOpcode::kSubtract:
    case HloOpcode::kCos:
    case HloOpcode::kSin:
    case HloOpcode::kNegate:
    case HloOpcode::kAbs:
    case HloOpcode::kDivide: {
      PostorderDFSNode result;
      for (auto operand_id : root->operand_ids()) {
        result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
                             context);
      }
      return result.AddVisit(
          [root](absl::Span<Literal> operands) -> StatusOr<Literal> {
            return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
          });
    }
    case HloOpcode::kCustomCall: {
      if (root->custom_call_target() == "SetBound") {
        // `SetBound` doesn't change the static value of a tensor, so forward
        // the operand when analyzing static value.
        return PostorderDFSNode()
            .AddDependency(root->operand_ids(0),
                           PostorderDFSNodeType::kConstantValue, context)
            .AddVisit(
                [](Literal operand) -> StatusOr<Literal> { return operand; });
      } else if (root->custom_call_target() == "Sharding") {
        return PostorderDFSNode()
            .AddDependency(root->operand_ids(0),
                           PostorderDFSNodeType::kConstantValue, context)
            .AddVisit([](Literal operand) { return operand; });
      } else {
        return PostorderDFSNode().AddVisit(
            [root, context](absl::Span<Literal>) {
              // The value is dynamic. We return a garbage literal here, which
              // will be masked out later.
              return CreateGarbageLiteral(ShapeUtil::GetSubshape(
                  Shape(root->shape()), context.shape_index));
            });
      }
    }
    case HloOpcode::kSort: {
      PostorderDFSNode result;
      InferenceContext dep_context = context;
      dep_context.shape_index = {};
      for (auto operand_id : root->operand_ids()) {
        result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
                             dep_context);
      }
      const HloComputationProto* computation_proto =
          handle_to_computation(root->called_computation_ids(0));
      return result.AddVisit(
          [root, context, computation_proto](
              absl::Span<Literal> operands) -> StatusOr<Literal> {
            TF_ASSIGN_OR_RETURN(
                auto computation,
                HloComputation::CreateFromProto(*computation_proto, {}));
            return HloProtoEvaluator(*root)
                .WithOperands(operands)
                .WithComputation(std::move(computation))
                .WithSubshape(context.shape_index)
                .Evaluate();
          });
    }
    default:
      return AnalyzeConstantValueFallback(
          handle, PostorderDFSNodeType::kConstantValue, context);
  }
}

StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeIsDynamic(
    int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
  const HloInstructionProto* root = handle_to_instruction(handle);
  // Invariant check.
  TF_RET_CHECK(root);
  VLOG(1) << "Analyzing IsDynamic on " << root->DebugString();
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
  PostorderDFSNode result;
  for (auto operand_id : root->operand_ids()) {
    InferenceContext dep_context = context;
    dep_context.shape_index = {};
    result.AddDependency(operand_id, type, dep_context);
  }
  switch (opcode) {
    case HloOpcode::kGetDimensionSize: {
      int64_t dimension = root->dimensions(0);
      int64_t operand_handle = root->operand_ids(0);
      const HloInstructionProto* operand_proto =
          handle_to_instruction(operand_handle);
      return PostorderDFSNode().AddVisit(
          [operand_proto, dimension, type]() -> StatusOr<Literal> {
            if (type == PostorderDFSNodeType::kBoundIsDynamic) {
              // The bound of a dynamic dimension is not dynamic.
              return LiteralUtil::CreateR0<bool>(false);
            }
            // The value of a dynamic dimension is dynamic.
            return LiteralUtil::CreateR0<bool>(
                operand_proto->shape().is_dynamic_dimension(dimension));
          });
    }
    case HloOpcode::kSort: {
      auto dfs = PostorderDFSNode();
      InferenceContext dep_context = context;
      dep_context.shape_index = {};

      for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
        dfs.AddDependency(root->operand_ids(i), type, dep_context);
      }

      return dfs.AddVisit([root, context, type](absl::Span<Literal> operands)
                              -> StatusOr<Literal> {
        bool all_operands_values_static = true;
        for (int64_t i = 0; i < operands.size(); ++i) {
          all_operands_values_static &= operands[i].IsAll(0);
        }
        if (type == PostorderDFSNodeType::kValueIsDynamic) {
          // If even a single operand of a sort is dynamic, we
          // conservatively say all results are dynamic.
          return CreatePredLiteral(!all_operands_values_static,
                                   ShapeUtil::GetSubshape(Shape(root->shape()),
                                                          context.shape_index));
        }
        CHECK(type == PostorderDFSNodeType::kBoundIsDynamic);
        // The condition for bounds is more relaxed than for values. If we
        // know the bounds of each element [B0, B1... Bn], all results have
        // the same bound
        // [max(B0, B1...), max(B0, B1...), ...]
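        // E.g., for operand bounds [3, 7, 5], every element of the sorted
        // result is bounded by max(3, 7, 5) == 7, so the result's bound is
        // [7, 7, 7], and that shared bound is dynamic iff any element of the
        // operand's bound is dynamic.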
        if (!context.shape_index.empty()) {
          int64_t index = context.shape_index[0];
          bool all_values_static = operands[index].IsAll(0);
          return CreatePredLiteral(!all_values_static, operands[index].shape());
        }

        std::vector<Literal> results;
        results.reserve(operands.size());
        for (int64_t i = 0; i < operands.size(); ++i) {
          bool all_values_static = operands[i].IsAll(0);
          results.emplace_back(
              CreatePredLiteral(!all_values_static, operands[i].shape()));
        }
        if (!ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index)
                 .IsTuple()) {
          return std::move(results[0]);
        }
        return LiteralUtil::MakeTupleOwned(std::move(results));
      });
    }
    case HloOpcode::kSetDimensionSize:
      return result.AddVisit([root, type](absl::Span<Literal> operands) {
        bool any_dynamic_operand = absl::c_any_of(
            operands, [](Literal& operand) { return !operand.IsAll(0); });
        // If the values in a tensor `t` with bound are [e0, e1, e2...], we can
        // say the max value of each position is [max(t), max(t), max(t), ...].
        // The effective size of this tensor doesn't change its max value.
        return CreatePredLiteral(
            type == PostorderDFSNodeType::kValueIsDynamic &&
                any_dynamic_operand,
            ShapeUtil::MakeStaticShape(Shape(root->shape())));
      });
    case HloOpcode::kDynamicSlice: {
      return result.AddVisit([root](absl::Span<Literal> operands) {
        // If any of the operands is dynamic, we say the output is dynamic.
        bool any_dynamic_operand = absl::c_any_of(
            operands, [](Literal& operand) { return !operand.IsAll(0); });
        return CreatePredLiteral(any_dynamic_operand, Shape(root->shape()));
      });
    }
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCos:
    case HloOpcode::kClz:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kConvert:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh: {
      // Forward the operand, as these ops don't change whether a value is
      // dynamic or static.
      return result.AddVisit([](Literal operand) { return operand; });
    }
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kDivide:
    case HloOpcode::kComplex:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kCompare:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical: {
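      // An elementwise binary op's output is dynamic at an index iff either
      // operand is dynamic there, so we evaluate the operands' PRED dynamism
      // masks with the opcode rewritten to kOr.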
      return result.AddVisit([root](absl::Span<Literal> operands) {
        return HloProtoEvaluator(*root)
            .WithOperands(operands)
            .WithPrimitiveType(PRED)
            .WithOpCode(HloOpcode::kOr)
            .Evaluate();
      });
    }
    case HloOpcode::kTuple:
    case HloOpcode::kTranspose:
    case HloOpcode::kSlice:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kPad: {
      if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
        // There could be many operands of a tuple, but only one that we are
        // interested in, represented by `tuple_operand_index`.
        int64_t tuple_operand_index = context.shape_index.front();
        InferenceContext tuple_operand_context = context;
        tuple_operand_context.shape_index.pop_front();
        return PostorderDFSNode()
            .AddDependency(root->operand_ids(tuple_operand_index), type,
                           tuple_operand_context)
            .AddVisit([](Literal operand) { return operand; });
      }
      return result.AddVisit([root](absl::Span<Literal> operands) {
        return HloProtoEvaluator(*root)
            .WithOperands(operands)
            .WithPrimitiveType(PRED)
            .Evaluate();
      });
    }
    case HloOpcode::kConditional: {
      auto node = PostorderDFSNode();
      auto* conditional_proto = root;
      InferenceContext predicate_context = context;
      predicate_context.shape_index = {};
      // Add dependencies to analyze the predicate of the conditional.
      node.AddDependency(conditional_proto->operand_ids(0),
                         PostorderDFSNodeType::kConstantValue,
                         predicate_context)
          .AddDependency(conditional_proto->operand_ids(0),
                         PostorderDFSNodeType::kValueIsDynamic,
                         predicate_context);
      const int64_t branch_size =
          conditional_proto->called_computation_ids_size();
      for (int64_t i = 0; i < branch_size; ++i) {
        int64_t branch_root =
            handle_to_computation(conditional_proto->called_computation_ids(i))
                ->root_id();
        InferenceContext branch_context = context;
        branch_context.caller_operand_handles.push_back(
            conditional_proto->operand_ids(i + 1));
        node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
                           branch_context,
                           absl::StrFormat("branch %lld's value", i))
            .AddDependency(branch_root, PostorderDFSNodeType::kValueIsDynamic,
                           branch_context,
                           absl::StrFormat("branch %lld's dynamism", i));
      }
      // The predicate contributes 2 dependencies:
      // 0: Predicate value.
      // 1: Predicate is dynamic.
      // Each branch i contributes 2 dependencies:
      // 2 + 2*i: Branch result value.
      // 2 + 2*i + 1: Branch value is dynamic.
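      // E.g., with two branches, the resolved operand span is
      //   [pred value, pred dynamism,
      //    branch 0 value, branch 0 dynamism,
      //    branch 1 value, branch 1 dynamism],
      // so branch 1's dynamism lives at index 2 + 2*1 + 1 == 5.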
      return node.AddVisit(
          [root, branch_size,
           context](absl::Span<Literal> operands) -> StatusOr<Literal> {
            int64_t pred_is_dynamic = operands[1].Get<bool>({});
            auto result = CreatePredLiteral(
                true, ShapeUtil::GetSubshape(Shape(root->shape()),
                                             context.shape_index));
            if (pred_is_dynamic) {
              VLOG(1) << "predicate is dynamic, value " << result.ToString();
              // If the predicate is dynamic, the result is only static if all
              // branches are static and return the same value.
              result.MutableEachCell<bool>([&](absl::Span<const int64> indices,
                                               bool value) {
                string branch_value = operands[2].GetAsString(indices, {});
                for (int64_t i = 0; i < branch_size; ++i) {
                  const int64_t branch_value_index = 2 + 2 * i;
                  const int64_t branch_dynamism_index = 2 + 2 * i + 1;
                  auto branch_is_dynamic =
                      operands[branch_dynamism_index].Get<bool>(indices);
                  if (branch_is_dynamic) {
                    return true;
                  }

                  if (branch_value !=
                      operands[branch_value_index].GetAsString(indices, {})) {
                    return true;
                  }
                }
                // Value of the branch is static.
                return false;
              });
              return result;
            } else {
              VLOG(1) << "predicate is a constant value";
              // If the predicate is static, return true if the given branch's
              // result value is dynamic.
              int64_t branch_index = 0;
              if (operands[0].shape().element_type() == PRED) {
                if (operands[0].Get<bool>({})) {
                  branch_index = 0;
                } else {
                  branch_index = 1;
                }
              } else {
                branch_index = operands[0].GetIntegralAsS64({}).value();
              }
              const int64_t branch_dynamism_index = 2 + 2 * branch_index + 1;
              return std::move(operands[branch_dynamism_index]);
            }
          });
    }
    case HloOpcode::kGetTupleElement: {
      int64_t operand_handle = root->operand_ids(0);
      PostorderDFSNode result;
      context.shape_index.push_front(root->tuple_index());
      return PostorderDFSNode()
          .AddDependency(operand_handle, type, context)
          .AddVisit([](Literal operand) { return operand; });
    }

    case HloOpcode::kReduce: {
      return result.AddVisit([root, context](absl::Span<Literal> operands) {
        Shape root_shape = Shape(root->shape());
        Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
        std::unique_ptr<HloComputation> reduce_or;
        if (root_shape.IsTuple()) {
          // Variadic reduce.
          HloComputation::Builder b("reduce_or");
          std::vector<HloInstruction*> results;
          results.reserve(root_shape.tuple_shapes_size());
          for (int64_t i = 0; i < root_shape.tuple_shapes_size(); ++i) {
            auto lhs = b.AddInstruction(
                HloInstruction::CreateParameter(i, scalar_shape, "lhs"));
            auto rhs = b.AddInstruction(HloInstruction::CreateParameter(
                i + root_shape.tuple_shapes_size(), scalar_shape, "rhs"));
            results.push_back(b.AddInstruction(HloInstruction::CreateBinary(
                scalar_shape, HloOpcode::kOr, lhs, rhs)));
          }
          b.AddInstruction(HloInstruction::CreateTuple(results));
          reduce_or = b.Build();
        } else {
          HloComputation::Builder b("reduce_or");
          auto lhs = b.AddInstruction(
              HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
          auto rhs = b.AddInstruction(
              HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
          b.AddInstruction(HloInstruction::CreateBinary(
              scalar_shape, HloOpcode::kOr, lhs, rhs));
          reduce_or = b.Build();
        }

        return HloProtoEvaluator(*root)
            .WithOperands(operands)
            .WithPrimitiveType(PRED)
            .WithComputation(std::move(reduce_or))
            // Reduce could produce a tuple shape; only fetch what we need.
            .WithSubshape(context.shape_index)
            .Evaluate();
      });
    }
    case HloOpcode::kConstant:
    case HloOpcode::kIota: {
      return result.AddVisit(
          [root]() { return CreatePredLiteral(false, Shape(root->shape())); });
    }
    case HloOpcode::kParameter: {
      if (opcode == HloOpcode::kParameter &&
          !context.caller_operand_handles.empty()) {
        int64_t caller_operand = context.caller_operand_handles.back();
        context.caller_operand_handles.pop_back();
        return result.AddDependency(caller_operand, type, context)
            .AddVisit([](Literal literal) { return literal; });
      }
      return result.AddVisit([root, context]() {
        return CreatePredLiteral(
            true,
            ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
      });
    }
1203     case HloOpcode::kSelect: {
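      // A select needs both the selector's constant value and its dynamism:
      // where the selector is statically known, the result's dynamism is
      // simply that of the chosen branch; where the selector itself is
      // dynamic, the result is conservatively dynamic.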
1204       return PostorderDFSNode()
1205           .AddDependency(root->operand_ids(0),
1206                          PostorderDFSNodeType::kConstantValue, context)
1207           .AddDependency(root->operand_ids(0),
1208                          PostorderDFSNodeType::kValueIsDynamic, context)
1209           // lhs dependency.
1210           .AddDependency(root->operand_ids(1), type, context)
1211           // rhs dependency.
1212           .AddDependency(root->operand_ids(2), type, context)
1213           .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
1214             OptionalLiteral optional_selector_literal(std::move(operands[0]),
1215                                                       std::move(operands[1]));
1216             Literal lhs = std::move(operands[2]);
1217             Literal rhs = std::move(operands[3]);
1218             auto result = CreatePredLiteral(true, Shape(root->shape()));
1219             result.MutableEachCell<bool>(
1220                 [&](absl::Span<const int64> indices, bool value) {
1221                   absl::optional<bool> optional_selector =
1222                       optional_selector_literal.Get<bool>(indices);
1223 
1224                   bool lhs_value = lhs.Get<bool>(indices);
1225                   bool rhs_value = rhs.Get<bool>(indices);
1226                   if (optional_selector.has_value()) {
1227                     // Manually evaluate the selection without using Evaluator.
1228                     if (*optional_selector) {
1229                       return lhs_value;
1230                     } else {
1231                       return rhs_value;
1232                     }
1233                   } else {
1234                     // Conservatively assume value is dynamic if selector is
1235                     // dynamic.
1236                     return true;
1237                   }
1238                 });
1239             return result;
1240           });
1241     }
1242     case HloOpcode::kGather: {
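      // A gather is only analyzable when its indices are fully static,
      // since we must know which input elements feed each output element.
      // The data operand's dynamism literal is then routed through the
      // evaluator as a PRED-typed gather input.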
1243       return PostorderDFSNode()
1244           .AddDependency(root->operand_ids(0), type, context)
1245           .AddDependency(root->operand_ids(1),
1246                          PostorderDFSNodeType::kConstantValue, context)
1247           .AddDependency(root->operand_ids(1),
1248                          PostorderDFSNodeType::kValueIsDynamic, context)
1249           .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
1250             OptionalLiteral optional_selector_literal(std::move(operands[1]),
1251                                                       std::move(operands[2]));
1252 
1253             if (!optional_selector_literal.AllValid()) {
1254               // Conservatively assume results are dynamic.
1255               return CreatePredLiteral(true, Shape(root->shape()));
1256             }
1257             std::vector<Literal> new_operands;
1258             new_operands.emplace_back(std::move(operands[0]));
1259             new_operands.emplace_back(
1260                 optional_selector_literal.GetValue()->Clone());
1261 
1262             return HloProtoEvaluator(*root)
1263                 .WithOperands(absl::MakeSpan(new_operands))
1264                 .WithPrimitiveType(PRED)
1265                 .Evaluate();
1266           });
1267     }
1268     case HloOpcode::kCustomCall: {
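      // "SetBound" carries its dynamism information inline: the bound
      // itself is always static (kBoundIsDynamic yields all-false), while
      // value dynamism is encoded in the attached literal -- either
      // directly, or as the second element of a (bounds, dynamism) tuple.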
1269       if (root->custom_call_target() == "SetBound") {
1270         return PostorderDFSNode().AddVisit([type, root]() -> StatusOr<Literal> {
1271           if (type == PostorderDFSNodeType::kBoundIsDynamic) {
1272             return CreatePredLiteral(false, Shape(root->shape()));
1273           } else {
1274             if (root->literal().shape().element_type() == TUPLE) {
1275               // The first literal of SetBound holds the bounds; the second
1276               // holds the dynamism indicators.
1277               return Literal::CreateFromProto(
1278                   root->literal().tuple_literals(1));
1279             } else {
1280               return Literal::CreateFromProto(root->literal());
1281             }
1282           }
1283         });
1284       } else if (root->custom_call_target() == "Sharding") {
1285         return result.AddVisit([](Literal operand) { return operand; });
1286       } else {
1287         return InvalidArgument(
1288             "Dynamic inference on custom call %s is not supported",
1289             root->DebugString());
1290       }
1291 
1292       break;
1293     }
1294 
1295     case HloOpcode::kCall:
1296     case HloOpcode::kRecv:
1297     case HloOpcode::kRecvDone:
1298     case HloOpcode::kSend:
1299     case HloOpcode::kSendDone:
1300     case HloOpcode::kWhile: {
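      // These opcodes either have side effects or involve control flow that
      // is not modeled here, so their results are conservatively marked
      // fully dynamic.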
1301       return PostorderDFSNode().AddVisit([root,
1302                                           context]() -> StatusOr<Literal> {
1303         return CreatePredLiteral(
1304             true,
1305             ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1306       });
1307       break;
1308     }
1309     default:
1310       return PostorderDFSNode().AddVisit([root,
1311                                           context]() -> StatusOr<Literal> {
1312         return CreatePredLiteral(
1313             true,
1314             ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1315       });
1316   }
1317 }
1318 
1319 StatusOr<Literal> PostorderDFSVisitor::PostOrderDFSVisit(
1320     int64_t handle, PostorderDFSNodeType type) {
1321   enum VisitState {
1322     kUnvisited = 0,
1323     kVisiting,
1324     kVisited,
1325   };
1326 
1327   int64_t unique_id = 0;
1328   struct WorkItem {
1329     explicit WorkItem(int64_t handle, InferenceContext context,
1330                       PostorderDFSNodeType type, VisitState state, int64_t id)
1331         : handle(handle),
1332           context(std::move(context)),
1333           type(type),
1334           state(state),
1335           id(id) {}
1336     int64 handle;  // Handle of the node in the graph.
1337     InferenceContext context;
1338     PostorderDFSNodeType type;
1339     VisitState state;
1340     Visit visit;  // The handler to call once the dependencies are resolved into
1341                   // literal form.
1342     int64 id;     // Unique id in the work queue, starting from 0.
1343     std::vector<CacheKey> dependencies;
1344 
1345     CacheKey GetCacheKey() { return CacheKey(handle, context, type); }
1346   };
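  // Iterative two-phase DFS over an explicit stack. When a node first
  // reaches the top of the stack it is marked kVisiting and its
  // dependencies are pushed; when it surfaces again, all dependencies have
  // been evaluated into `evaluated`, so its visit function runs and the
  // result is cached. A rough trace for c = a + b:
  //   push(c); c -> kVisiting, push(a), push(b);
  //   b evaluated; a evaluated; c.visit({lit(a), lit(b)}) -> evaluated[c].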
1347 
1348   std::vector<WorkItem> stack;
1349   WorkItem root(handle, InferenceContext({}, {}), type, kUnvisited,
1350                 unique_id++);
1351   stack.push_back(root);
1352   while (!stack.empty()) {
1353     WorkItem& item = stack.back();
1354     VLOG(1) << "stack top shape index: " << item.context.shape_index.ToString();
1355     VLOG(1) << "stack top "
1356             << handle_to_instruction(item.handle)->DebugString();
1357     if (item.state == kVisiting) {
1358       VLOG(1) << "visiting";
1359       // The dependencies are ready, visit the node itself.
1360 
1361       // Gather dependencies and transform them into literals.
1362       std::vector<Literal> literals;
1363       for (CacheKey& dep_key : item.dependencies) {
1364         TF_RET_CHECK(evaluated.contains(dep_key));
1365         literals.emplace_back(evaluated.at(dep_key).Clone());
1366       }
1367       VLOG(1) << "Start visiting with dependency type: "
1368               << PostorderDFSNodeTypeToString(item.type);
1369       TF_ASSIGN_OR_RETURN(auto literal, item.visit(absl::MakeSpan(literals)));
1370       VLOG(1) << "End visiting: " << literal.ToString();
1371       evaluated[item.GetCacheKey()] = std::move(literal);
1372       stack.pop_back();
1373       continue;
1374     }
1375     // This is the first time we see this node; gather its dependencies.
1377     VLOG(1) << "unvisited";
1378     if (evaluated.contains(item.GetCacheKey())) {
1379       stack.pop_back();
1380       continue;
1381     }
1382     item.state = kVisiting;
1383     PostorderDFSNode node;
1384     switch (item.type) {
1385       case PostorderDFSNodeType::kConstantValue: {
1386         VLOG(1) << "constant value";
1387         TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle, item.context));
1388         break;
1389       }
1390       case PostorderDFSNodeType::kConstantLowerBound: {
1391         VLOG(1) << "constant lower bound";
1392         TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle, item.context));
1393         break;
1394       }
1395       case PostorderDFSNodeType::kConstantUpperBound: {
1396         VLOG(1) << "constant upper bound";
1397         TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle, item.context));
1398         break;
1399       }
1400       case PostorderDFSNodeType::kBoundIsDynamic:
1401       case PostorderDFSNodeType::kValueIsDynamic: {
1402         VLOG(1) << "value is dynamic";
1403         TF_ASSIGN_OR_RETURN(
1404             node, AnalyzeIsDynamic(item.handle, item.type, item.context));
1405         break;
1406       }
1407     }
1408     // Store the visit function, to be invoked once its dependencies are
1409     // resolved.
1410     item.visit = node.visit;
1411 
1412     const int64 current_item_id = stack.size() - 1;
1413     // Enqueue dependencies onto the stack. `item` must not be accessed after
1414     // this point, since the emplace_back below may reallocate `stack`.
1415     for (const PostorderDFSDep& dep : node.dependencies) {
1416       VLOG(1) << "dep " << dep.annotation
1417               << "::" << handle_to_instruction(dep.handle)->DebugString()
1418               << " index " << dep.context.shape_index
1419               << " stack size:" << stack.size();
1420       stack.emplace_back(dep.handle, dep.context, dep.type, kUnvisited,
1421                          unique_id++);
1422       stack[current_item_id].dependencies.push_back(stack.back().GetCacheKey());
1423     }
1424   }
1425   VLOG(1) << "done: " << evaluated[root.GetCacheKey()].ToString();
1426   return evaluated[root.GetCacheKey()].Clone();
1427 }
1428 
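// A minimal usage sketch (hypothetical client code; assumes `builder` is
// the XlaBuilder on which the ops below were created):
//
//   XlaOp size = GetDimensionSize(param, 0);
//   ValueInference value_inference(&builder);
//   Literal dynamism = value_inference.AnalyzeIsDynamic(size).ValueOrDie();
//   // dynamism.Get<bool>({}) is true iff dimension 0 of `param` is
//   // dynamic.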
1429 StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
1430   PostorderDFSVisitor visitor(
1431       [&](int64_t handle) {
1432         return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1433       },
1434       [&](int64_t handle) { return &(builder_->embedded_[handle]); });
1435   auto result = visitor.PostOrderDFSVisit(
1436       op.handle(), PostorderDFSNodeType::kValueIsDynamic);
1437   return result;
1438 }
1439 
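// Lightweight CSE keyed on (operand handle, dimension): duplicated
// kGetDimensionSize ops are very common, and folding them onto a single
// representative handle lets cached analysis results be reused instead of
// being recomputed per duplicate. The hash is verified field-by-field
// below, so a 64-bit collision can never alias two distinct ops.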
1440 absl::optional<int64> ValueInference::CseOpHandle(int64_t handle) {
1441   auto inst = builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1442   HloOpcode opcode = StringToHloOpcode(inst->opcode()).ValueOrDie();
1443   // For now, only handle kGetDimensionSize as that's the most duplicated one.
1444   if (opcode != HloOpcode::kGetDimensionSize) {
1445     return absl::nullopt;
1446   }
1447   int64_t hash = inst->operand_ids(0);
1448   hash =
1449       tensorflow::Hash64Combine(hash, std::hash<int64>()(inst->dimensions(0)));
1450   auto lookup = cse_map_.find(hash);
1451   if (lookup == cse_map_.end()) {
1452     cse_map_[hash] = handle;
1453     return absl::nullopt;
1454   }
1455   auto equivalent_op =
1456       builder_->LookUpInstructionByHandle(lookup->second).ValueOrDie();
1457   // Check that the op is indeed equivalent, to guard against hash
1458   // collisions -- relatively easy to hit with a 64-bit hash.
1459   if (equivalent_op->opcode() != inst->opcode() ||
1460       equivalent_op->operand_ids(0) != inst->operand_ids(0) ||
1461       equivalent_op->dimensions(0) != inst->dimensions(0)) {
1462     // Hash collision, don't CSE.
1463     return absl::nullopt;
1464   }
1465   int64_t cse = lookup->second;
1466   if (handle != cse) {
1467     // Found an equivalent handle distinct from the input; CSE to it.
1468     return cse;
1469   }
1470   return absl::nullopt;
1471 }
1472 
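// SimplifyOp performs a small symbolic simplification over scalar integer
// expressions. The returned S64 literal holds instruction *handles* rather
// than values: each element names the op that computes it, with the
// sentinel -1 meaning "no simplified form known". This lets algebraic
// rewrites such as a + (b - a) => b operate on handles without
// materializing any data.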
1473 StatusOr<Literal> ValueInference::SimplifyOp(int64_t handle) {
1474   if (auto cse_handle = CseOpHandle(handle)) {
1475     // Use the CSE'd handle instead.
1476     return SimplifyOp(*cse_handle);
1477   }
1478   auto inst = builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1479   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode()));
1480   std::vector<Literal> operands;
1481   auto output_shape = Shape(inst->shape());
1482   switch (opcode) {
1483     case HloOpcode::kSlice:
1484     case HloOpcode::kConcatenate:
1485     case HloOpcode::kReshape:
1486     case HloOpcode::kBroadcast: {
1487       for (auto operand_id : inst->operand_ids()) {
1488         TF_ASSIGN_OR_RETURN(auto literal, SimplifyOp(operand_id));
1489         operands.emplace_back(std::move(literal));
1490       }
1491       // We put handles into the tensor and evaluate the result into a literal;
1492       // the resulting literal contains a handle for each element position.
1493       return HloProtoEvaluator(*inst)
1494           .WithOperands(absl::MakeSpan(operands))
1495           .WithPrimitiveType(S64)
1496           .Evaluate();
1497     }
1498     case HloOpcode::kConvert: {
1499       // Only identity kConvert can be optimized away.
1500       auto operand = builder_->LookUpInstructionByHandle(inst->operand_ids(0))
1501                          .ValueOrDie();
1502       if (Shape::Equal()(output_shape, Shape(operand->shape()))) {
1503         // Forward operand handle as result.
1504         return SimplifyOp(inst->operand_ids(0));
1505       } else {
1506         return CreateS64Literal(-1, output_shape);
1507       }
1508     }
1509     case HloOpcode::kAdd: {
1510       // a + (b - a) => b
1511       // a + b + (c - a) => b + c
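      // Worked example: for s = a + (b - a), lhs simplifies to handle(a)
      // and rhs to the subtract's handle; can_be_optimized below sees
      // sub_rhs == handle(a) and returns handle(b), so s collapses to b.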
1512       if (output_shape.rank() == 0) {
1513         TF_ASSIGN_OR_RETURN(auto lhs, SimplifyOp(inst->operand_ids(0)));
1514         TF_ASSIGN_OR_RETURN(auto rhs, SimplifyOp(inst->operand_ids(1)));
1515         int64_t lhs_handle = lhs.Get<int64>({});
1516         int64_t rhs_handle = rhs.Get<int64>({});
1517         if (lhs_handle == -1 || rhs_handle == -1) {
1518           return CreateS64Literal(-1, output_shape);
1519         }
1520         // Recursive lambda needs explicit signature.
1521         std::function<absl::optional<int64>(int64_t, int64_t)> can_be_optimized;
1522         can_be_optimized = [this, &can_be_optimized](
1523                                int64_t lhs,
1524                                int64_t rhs) -> absl::optional<int64> {
1525           auto rhs_inst = builder_->LookUpInstructionByHandle(rhs).ValueOrDie();
1526           HloOpcode rhs_opcode =
1527               StringToHloOpcode(rhs_inst->opcode()).ValueOrDie();
1528           if (rhs_opcode == HloOpcode::kSubtract) {
1529             auto sub_lhs_handle = SimplifyOp(rhs_inst->operand_ids(0))
1530                                       .ValueOrDie()
1531                                       .Get<int64>({});
1532             auto sub_rhs_handle = SimplifyOp(rhs_inst->operand_ids(1))
1533                                       .ValueOrDie()
1534                                       .Get<int64>({});
1535             if (sub_rhs_handle == lhs) {
1536               // lhs + (sub_lhs - sub_rhs) = sub_lhs if lhs == sub_rhs
1537               return sub_lhs_handle;
1538             }
1539           }
1540 
1541           // Check the case for a + b + (c - a) => b + c
1542           auto lhs_inst = builder_->LookUpInstructionByHandle(lhs).ValueOrDie();
1543           HloOpcode lhs_opcode =
1544               StringToHloOpcode(lhs_inst->opcode()).ValueOrDie();
1545           if (lhs_opcode == HloOpcode::kAdd) {
1546             auto add_lhs_handle = SimplifyOp(lhs_inst->operand_ids(0))
1547                                       .ValueOrDie()
1548                                       .Get<int64>({});
1549             auto add_rhs_handle = SimplifyOp(lhs_inst->operand_ids(1))
1550                                       .ValueOrDie()
1551                                       .Get<int64>({});
1552             if (auto optimized = can_be_optimized(add_lhs_handle, rhs)) {
1553               return Add(XlaOp(add_rhs_handle, builder_),
1554                          XlaOp(optimized.value(), builder_))
1555                   .handle();
1556             }
1557             if (auto optimized = can_be_optimized(add_rhs_handle, rhs)) {
1558               return Add(XlaOp(add_lhs_handle, builder_),
1559                          XlaOp(optimized.value(), builder_))
1560                   .handle();
1561             }
1562           }
1563           return absl::nullopt;
1564         };
1565         if (auto optimized = can_be_optimized(lhs_handle, rhs_handle)) {
1566           return LiteralUtil::CreateR0<int64>(optimized.value());
1567         }
1568         // Swap lhs and rhs.
1569         if (auto optimized = can_be_optimized(rhs_handle, lhs_handle)) {
1570           return LiteralUtil::CreateR0<int64>(optimized.value());
1571         }
1572         // This sum can't be optimized; return the sum of lhs and rhs. Note
1573         // that we can't just return the original sum, since its lhs and rhs
1574         // may themselves have been simplified into different handles.
1575         XlaOp new_sum =
1576             Add(XlaOp(lhs_handle, builder_), XlaOp(rhs_handle, builder_));
1577 
1578         return LiteralUtil::CreateR0<int64>(new_sum.handle());
1579       } else {
1580         return CreateS64Literal(-1, output_shape);
1581       }
1582     }
1583     default: {
1584       if (ShapeUtil::IsScalar(output_shape)) {
1585         return LiteralUtil::CreateR0<int64>(handle);
1586       } else {
1587         return CreateS64Literal(-1, output_shape);
1588       }
1589     }
1590   }
1591 }
1592 
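// AnalyzeConstant pairs a value literal with a dynamism mask in an
// OptionalLiteral: elements whose mask is set read back as nullopt. Bound
// modes pair the bound with the kBoundIsDynamic mask; kValue pairs the
// constant value with kValueIsDynamic. A usage sketch (hypothetical client
// code):
//
//   auto result =
//       value_inference.AnalyzeConstant(op, ValueInferenceMode::kUpperBound)
//           .ValueOrDie();
//   if (auto bound = result.Get<int32>({})) {
//     // *bound is a trustworthy upper bound for the scalar `op`.
//   }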
1593 StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
1594     XlaOp op, ValueInferenceMode mode) {
1595   PostorderDFSVisitor visitor(
1596       [&](int64_t handle) {
1597         return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1598       },
1599       [&](int64_t handle) { return &(builder_->embedded_[handle]); });
1600   int64_t handle = op.handle();
1601   if (ShapeUtil::IsScalar(builder_->GetShape(op).ValueOrDie())) {
1602     TF_ASSIGN_OR_RETURN(auto result, SimplifyOp(handle));
1603     auto optimized_handle = result.Get<int64>({});
1604     if (optimized_handle != -1) {
1605       handle = optimized_handle;
1606     }
1607   }
1608   switch (mode) {
1609     case ValueInferenceMode::kLowerBound: {
1610       TF_ASSIGN_OR_RETURN(
1611           Literal value,
1612           visitor.PostOrderDFSVisit(handle,
1613                                     PostorderDFSNodeType::kConstantLowerBound));
1614       TF_ASSIGN_OR_RETURN(Literal mask,
1615                           visitor.PostOrderDFSVisit(
1616                               handle, PostorderDFSNodeType::kBoundIsDynamic));
1617       return OptionalLiteral(std::move(value), std::move(mask));
1618     }
1619     case ValueInferenceMode::kUpperBound: {
1620       TF_ASSIGN_OR_RETURN(
1621           Literal value,
1622           visitor.PostOrderDFSVisit(handle,
1623                                     PostorderDFSNodeType::kConstantUpperBound));
1624       TF_ASSIGN_OR_RETURN(Literal mask,
1625                           visitor.PostOrderDFSVisit(
1626                               handle, PostorderDFSNodeType::kBoundIsDynamic));
1627 
1628       return OptionalLiteral(std::move(value), std::move(mask));
1629     }
1630     case ValueInferenceMode::kValue: {
1631       TF_ASSIGN_OR_RETURN(Literal value,
1632                           visitor.PostOrderDFSVisit(
1633                               handle, PostorderDFSNodeType::kConstantValue));
1634       TF_ASSIGN_OR_RETURN(Literal mask,
1635                           visitor.PostOrderDFSVisit(
1636                               handle, PostorderDFSNodeType::kValueIsDynamic));
1637       return OptionalLiteral(std::move(value), std::move(mask));
1638     }
1639   }
1640 }
1641 
1642 }  // namespace xla
1643