1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/client/value_inference.h"
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/comparison_util.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/primitive_util.h"
29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
30 #include "tensorflow/compiler/xla/service/hlo.pb.h"
31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
32 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/stream_executor/lib/statusor.h"
40
41 namespace xla {
42 namespace {
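// Creates a PRED literal of value `pred` with the same structure as
// `reference_shape`. For example (illustrative): a f32[2,3] reference shape
// yields a pred[2,3] literal filled with `pred`; a tuple reference shape is
// handled recursively, element by element.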
43 Literal CreatePredLiteral(bool pred, const Shape& reference_shape) {
44 if (reference_shape.IsTuple()) {
45 std::vector<Literal> sub_literals;
46 for (const Shape& shape : reference_shape.tuple_shapes()) {
47 sub_literals.emplace_back(CreatePredLiteral(pred, shape));
48 }
49 return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
50 }
51 PrimitiveType element_type = reference_shape.element_type();
52 if (element_type == TOKEN) {
53 return LiteralUtil::CreateR0(pred);
54 }
55 Literal literal = LiteralUtil::CreateR0(pred);
56 Literal literal_broadcast =
57 literal.Broadcast(ShapeUtil::ChangeElementType(reference_shape, PRED), {})
58 .ValueOrDie();
59 return literal_broadcast;
60 }
61
62 Literal CreateS64Literal(int64_t value, const Shape& reference_shape) {
63 if (reference_shape.IsTuple()) {
64 std::vector<Literal> sub_literals;
65 for (const Shape& shape : reference_shape.tuple_shapes()) {
66 sub_literals.emplace_back(CreateS64Literal(value, shape));
67 }
68 return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
69 }
70 PrimitiveType element_type = reference_shape.element_type();
71 if (element_type == TOKEN) {
72 return LiteralUtil::CreateToken();
73 }
74 Literal literal = LiteralUtil::CreateR0<int64>(value);
75 return literal
76 .Broadcast(ShapeUtil::ChangeElementType(reference_shape, S64), {})
77 .ValueOrDie();
78 }
79
80 // Create a literal with garbage data. The data inside is undefined and
81 // shouldn't be used in any meaningful computation.
82 Literal CreateGarbageLiteral(const Shape& reference_shape) {
83 if (reference_shape.IsTuple()) {
84 std::vector<Literal> sub_literals;
85 for (const Shape& shape : reference_shape.tuple_shapes()) {
86 sub_literals.emplace_back(CreateGarbageLiteral(shape));
87 }
88 return Literal::MoveIntoTuple(absl::MakeSpan(sub_literals));
89 }
90 PrimitiveType element_type = reference_shape.element_type();
91 if (element_type == TOKEN) {
92 return LiteralUtil::CreateToken();
93 }
94 Literal literal = LiteralUtil::One(element_type);
95 return literal.Broadcast(reference_shape, {}).ValueOrDie();
96 }
97
98 // HloProtoEvaluator evaluates an hlo proto and returns a literal. The user
99 // has to provide the operands as literals through the WithOperands method.
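//
// Illustrative usage sketch (not part of the original comment): to re-evaluate
// an instruction proto `proto` with its operands replaced by literals `lits`
// and its element type forced to PRED, one would write something like:
//   TF_ASSIGN_OR_RETURN(Literal result, HloProtoEvaluator(proto)
//                                           .WithOperands(absl::MakeSpan(lits))
//                                           .WithPrimitiveType(PRED)
//                                           .Evaluate());
// `proto` and `lits` are hypothetical placeholders; the builder calls mirror
// the members defined below.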
100 struct HloProtoEvaluator {
101 explicit HloProtoEvaluator(HloInstructionProto inst)
102 : inst(std::move(inst)),
103 module("EmptyModuleForEvaluation", HloModuleConfig()) {}
104
105 // WithComputation changes the called computation of the instruction being
106 // evaluated.
107 HloProtoEvaluator& WithComputation(
108 std::unique_ptr<HloComputation> new_computation) {
109 computation = new_computation.get();
110 computation->ClearUniqueIdInternal();
111 for (HloInstruction* inst : computation->instructions()) {
112 inst->ClearUniqueIdInternal();
113 }
114 module.AddEmbeddedComputation(std::move(new_computation));
115 return *this;
116 }
117
118 // WithPrimitiveType changes the primitive type of the instruction being
119 // evaluated.
120 HloProtoEvaluator& WithPrimitiveType(PrimitiveType new_primitive_type) {
121 primitive_type = new_primitive_type;
122 return *this;
123 }
124
125 // WithOpCode changes the opcode of the instruction being evaluated.
126 HloProtoEvaluator& WithOpCode(HloOpcode new_opcode) {
127 opcode = new_opcode;
128 return *this;
129 }
130
131 // WithOperands changes the operands of the instruction being evaluated.
132 HloProtoEvaluator& WithOperands(absl::Span<Literal> operands) {
133 this->operands = operands;
134 return *this;
135 }
136
137 // When WithSubshape is set, the result tuple shape will be decomposed and
138 // only the sub-literal at the given shape index will be returned.
139 HloProtoEvaluator& WithSubshape(ShapeIndex shape_index) {
140 this->shape_index = std::move(shape_index);
141 return *this;
142 }
143
144 StatusOr<Literal> Evaluate() {
145 // Evaluate the instruction by swapping its operands with constant
146 // instructions holding the given literals.
147 HloComputation::Builder builder("EmptyComputation");
148 absl::flat_hash_map<int64, HloInstruction*> operand_map;
149 for (int64_t i = 0; i < inst.operand_ids_size(); ++i) {
150 int64_t operand_handle = inst.operand_ids(i);
151 std::unique_ptr<HloInstruction> operand =
152 HloInstruction::CreateConstant(operands[i].Clone());
153 operand_map[operand_handle] = operand.get();
154 builder.AddInstruction(std::move(operand));
155 }
156
157 if (primitive_type.has_value()) {
158 *inst.mutable_shape() = ShapeUtil::ChangeElementType(
159 Shape(inst.shape()), primitive_type.value())
160 .ToProto();
161 }
162 if (opcode.has_value()) {
163 *inst.mutable_opcode() = HloOpcodeString(opcode.value());
164 }
165 absl::flat_hash_map<int64, HloComputation*> computation_map;
166 if (inst.called_computation_ids_size() != 0) {
167 TF_RET_CHECK(inst.called_computation_ids_size() == 1 &&
168 computation != nullptr)
169 << inst.DebugString();
170 computation_map[inst.called_computation_ids(0)] = computation;
171 }
172 TF_ASSIGN_OR_RETURN(
173 auto new_instruction,
174 HloInstruction::CreateFromProto(inst, operand_map, computation_map));
175 new_instruction->ClearUniqueIdInternal();
176 builder.AddInstruction(std::move(new_instruction));
177 auto computation = builder.Build();
178 module.AddEntryComputation(std::move(computation));
179 HloEvaluator evaluator;
180 if (shape_index.empty()) {
181 return evaluator.Evaluate(module.entry_computation()->root_instruction());
182 } else {
183 TF_ASSIGN_OR_RETURN(
184 auto result,
185 evaluator.Evaluate(module.entry_computation()->root_instruction()));
186 return result.SubLiteral(this->shape_index);
187 }
188 }
189
190 HloInstructionProto inst;
191
192 HloModule module;
193 absl::Span<Literal> operands;
194 ShapeIndex shape_index = {};
195 HloComputation* computation = nullptr;
196 absl::optional<PrimitiveType> primitive_type = absl::nullopt;
197 absl::optional<HloOpcode> opcode = absl::nullopt;
198 };
199
200 enum PostorderDFSNodeType {
201 // This node is about figuring out the constant value.
202 kConstantValue = 0,
203 // This node is about figuring out the constant bound.
204 kConstantUpperBound,
205 kConstantLowerBound,
206 // This node is about figuring out whether a value is dynamic.
207 kValueIsDynamic,
208 // This node is about figuring out whether a bound value is dynamic. It's
209 // similar to kValueIsDynamic, but views shape bound as static values.
210 kBoundIsDynamic,
211 };
212
213 std::string PostorderDFSNodeTypeToString(PostorderDFSNodeType type) {
214 switch (type) {
215 case kConstantValue:
216 return "kConstantValue";
217 case kConstantUpperBound:
218 return "kConstantUpperBound";
219 case kConstantLowerBound:
220 return "kConstantLowerBound";
221 case kValueIsDynamic:
222 return "kValueIsDynamic";
223 case kBoundIsDynamic:
224 return "kBoundIsDynamic";
225 }
226 }
227
228 struct InferenceContext {
229 explicit InferenceContext(ShapeIndex shape_index,
230 std::vector<int64> caller_operand_handles)
231 : shape_index(std::move(shape_index)),
232 caller_operand_handles(std::move(caller_operand_handles)) {}
233 // `shape_index` represents the subshape that we care about in the inference.
234 // It is used to avoid materializing the whole tuple when we only care about a
235 // sub-tensor of it.
236 ShapeIndex shape_index;
237
238 // caller_operand_handles is a stack that helps argument forwarding. The top
239 // of the stack represents the tensor to be forwarded to the
240 // parameter of the innermost function. E.g.,:
241 // inner_true_computation {
242 // inner_p0 = param()
243 // ...
244 // }
245 //
246 // true_computation {
247 // p0 = param()
248 // conditional(pred, p0, inner_true_computation,
249 // ...)
250 // }
251 //
252 // main {
253 // op = ..
254 // conditional(pred, op, true_computation, ...)
255 // }
256 //
257 // In this case, when we analyze inner_true_computation, the
258 // `caller_operand_handles` will be [op, p0] -- p0 is what should be
259 // forwarded to inner_p0 and op is what should be forwarded to p0. Similarly,
260 // when we analyze true_computation, the `caller_operand_handles` will be
261 // [op].
262 std::vector<int64> caller_operand_handles;
263 };
264
265 // Each node in the postorder traversal tree may depend on traversing the
266 // values of the node's children.
267 struct PostorderDFSDep {
268 explicit PostorderDFSDep(int64_t handle, PostorderDFSNodeType type,
269 InferenceContext context, std::string annotation)
270 : handle(handle),
271 type(type),
272 context(std::move(context)),
273 annotation(std::move(annotation)) {}
274 int64 handle;
275 PostorderDFSNodeType type;
276 InferenceContext context;
277 std::string annotation;
278 };
279
280 // This function represents the logic to visit a node once its dependencies
281 // (operands) are all resolved.
282 using Visit = std::function<StatusOr<Literal>(absl::Span<Literal>)>;
283 // Convenient specializations of the Visit function for different operand arities.
284 using Visit0D = std::function<StatusOr<Literal>()>;
285 using Visit1D = std::function<StatusOr<Literal>(Literal)>;
286 using Visit2D = std::function<StatusOr<Literal>(Literal, Literal)>;
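// For example (illustrative), a handler of the form
//   [](Literal lhs, Literal rhs) -> StatusOr<Literal> { ... }
// matches Visit2D; the AddVisit overloads below adapt the generic span form to
// these fixed arities.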
287
288 // A postorder dfs node can be visited once its dependency requests are all
289 // fulfilled.
290 struct TF_MUST_USE_RESULT PostorderDFSNode {
291 PostorderDFSNode& AddDependency(int64_t handle, PostorderDFSNodeType type,
292 InferenceContext context,
293 std::string annotation = "") {
294 dependencies.emplace_back(handle, type, std::move(context),
295 std::move(annotation));
296 return *this;
297 }
298
299 PostorderDFSNode& AddVisit(const Visit& visit) {
300 this->visit = visit;
301 return *this;
302 }
303
304 PostorderDFSNode& AddVisit(const Visit0D& visit) {
305 this->visit = [visit](absl::Span<Literal> literals) { return visit(); };
306 return *this;
307 }
308
309 PostorderDFSNode& AddVisit(const Visit1D& visit) {
310 this->visit = [visit](absl::Span<Literal> literals) {
311 return visit(std::move(literals[0]));
312 };
313 return *this;
314 }
315
316 PostorderDFSNode& AddVisit(const Visit2D& visit) {
317 this->visit = [visit](absl::Span<Literal> literals) {
318 return visit(std::move(literals[0]), std::move(literals[1]));
319 };
320 return *this;
321 }
322 std::vector<PostorderDFSDep> dependencies;
323 Visit visit;
324 };
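// Illustrative sketch (assumed example, not from the original source): a node
// that depends on one operand's constant value and simply forwards it could be
// built as
//   PostorderDFSNode()
//       .AddDependency(operand_handle, PostorderDFSNodeType::kConstantValue,
//                      context)
//       .AddVisit([](Literal operand) { return operand; });
// where `operand_handle` and `context` stand for values available at the call
// site.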
325
326 // Converts an integer handle to an HloInstructionProto.
327 using HandleToInstruction = std::function<const HloInstructionProto*(int64_t)>;
328 using HandleToComputation = std::function<const HloComputationProto*(int64_t)>;
329
330 struct PostorderDFSVisitor {
331 PostorderDFSVisitor(HandleToInstruction handle_to_instruction,
332 HandleToComputation handle_to_computation)
333 : handle_to_instruction(handle_to_instruction),
334 handle_to_computation(handle_to_computation) {}
335
336 StatusOr<PostorderDFSNode> AnalyzeUpperBound(int64_t handle,
337 InferenceContext context);
338 StatusOr<PostorderDFSNode> AnalyzeLowerBound(int64_t handle,
339 InferenceContext context);
340 StatusOr<PostorderDFSNode> AnalyzeIsDynamic(int64_t handle,
341 PostorderDFSNodeType type,
342 InferenceContext context);
343 StatusOr<PostorderDFSNode> AnalyzeConstant(int64_t handle,
344 InferenceContext context);
345 StatusOr<PostorderDFSNode> AnalyzeConstantValueFallback(
346 int64_t handle, PostorderDFSNodeType type, InferenceContext context);
347
348 StatusOr<Literal> PostOrderDFSVisit(int64_t handle,
349 PostorderDFSNodeType type);
350
351 // Returns true if a value represented by `handle` is of an integral type or
352 // a floating-point type that was just converted from an integral type.
353 // E.g.,:
354 // 1 -> true
355 // float(1) -> true
356 // 1.1 -> false
357 // 1.0 -> false -- We don't know the concrete value at runtime, except for its
358 // type.
359 bool IsValueEffectiveInteger(int64_t handle) {
360 const HloInstructionProto* instr = handle_to_instruction(handle);
361 if (primitive_util::IsIntegralType(instr->shape().element_type())) {
362 return true;
363 }
364 // Also returns true if this is a convert that converts an integer to float.
365 HloOpcode opcode = StringToHloOpcode(instr->opcode()).ValueOrDie();
366 if (opcode != HloOpcode::kConvert) {
367 return false;
368 }
369 const HloInstructionProto* parent =
370 handle_to_instruction(instr->operand_ids(0));
371 if (primitive_util::IsIntegralType(parent->shape().element_type())) {
372 return true;
373 }
374 return false;
375 }
376
377 struct CacheKey {
378 CacheKey(int64 handle, InferenceContext context, PostorderDFSNodeType type)
379 : handle(handle), context(context), type(type) {}
380 int64 handle;
381 InferenceContext context;
382 PostorderDFSNodeType type;
383
384 template <typename H>
385 friend H AbslHashValue(H h, const CacheKey& key) {
386 h = H::combine(std::move(h), key.handle);
387 h = H::combine(std::move(h), key.context.shape_index.ToString());
388 h = H::combine(std::move(h),
389 VectorString(key.context.caller_operand_handles));
390 h = H::combine(std::move(h), key.type);
391 return h;
392 }
393
394 friend bool operator==(const CacheKey& lhs, const CacheKey& rhs) {
395 return lhs.handle == rhs.handle &&
396 lhs.context.shape_index == rhs.context.shape_index &&
397 lhs.context.caller_operand_handles ==
398 rhs.context.caller_operand_handles &&
399 lhs.type == rhs.type;
400 }
401 };
402
403 absl::flat_hash_map<CacheKey, Literal> evaluated;
404 HandleToInstruction handle_to_instruction;
405 HandleToComputation handle_to_computation;
406 };
407
408 } // namespace
409
410 // Analyze a tensor's constant value, upper-bound value or lower-bound value.
411 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstantValueFallback(
412 int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
413 const HloInstructionProto* root = handle_to_instruction(handle);
414 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
415 PostorderDFSNode result;
416 // By default, the dependencies of the current node are its operands.
417 for (auto operand_id : root->operand_ids()) {
418 InferenceContext dep_context = context;
419 dep_context.shape_index = {};
420 result.AddDependency(operand_id, type, dep_context);
421 }
422 switch (opcode) {
423 // Non functional ops.
424 case HloOpcode::kRng:
425 case HloOpcode::kAllReduce:
426 case HloOpcode::kReduceScatter:
427 case HloOpcode::kInfeed:
428 case HloOpcode::kOutfeed:
429 case HloOpcode::kCall:
430 case HloOpcode::kCustomCall:
431 case HloOpcode::kWhile:
432 case HloOpcode::kSend:
433 case HloOpcode::kRecv:
434 case HloOpcode::kSendDone:
435 case HloOpcode::kRecvDone:
436 case HloOpcode::kParameter: {
437 if (opcode == HloOpcode::kParameter &&
438 !context.caller_operand_handles.empty()) {
439 int64_t caller_operand = context.caller_operand_handles.back();
440 context.caller_operand_handles.pop_back();
441 return result.AddDependency(caller_operand, type, context)
442 .AddVisit([](Literal literal) { return literal; });
443 }
444 return PostorderDFSNode().AddVisit([root, context](absl::Span<Literal>) {
445 // The value is dynamic. We return a garbage literal here, which
446 // will be masked out later.
447 return CreateGarbageLiteral(
448 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
449 });
450 }
451 // Subtract and Divide use lower-bound as second operand.
452 case HloOpcode::kSubtract:
453 case HloOpcode::kCos:
454 case HloOpcode::kSin:
455 case HloOpcode::kNegate:
456 case HloOpcode::kAbs:
457 case HloOpcode::kDivide:
458 case HloOpcode::kGetDimensionSize: {
459 return InvalidArgument(
460 "AnalyzeConstantValueFallback can't handle opcode: %s",
461 root->opcode());
462 }
463
464 case HloOpcode::kConditional: {
465 auto node = PostorderDFSNode();
466 auto* conditional_proto = root;
467 InferenceContext predicate_context = context;
468 predicate_context.shape_index = {};
469 // Add dependencies to analyze the predicate of the conditional.
470 node.AddDependency(conditional_proto->operand_ids(0),
471 PostorderDFSNodeType::kConstantValue,
472 predicate_context)
473 .AddDependency(conditional_proto->operand_ids(0),
474 PostorderDFSNodeType::kValueIsDynamic,
475 predicate_context);
476 const int64_t branch_size =
477 conditional_proto->called_computation_ids_size();
478 for (int64_t i = 0; i < branch_size; ++i) {
479 int64_t branch_root =
480 handle_to_computation(conditional_proto->called_computation_ids(i))
481 ->root_id();
482 InferenceContext branch_context = context;
483 branch_context.caller_operand_handles.push_back(
484 conditional_proto->operand_ids(i + 1));
485 node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
486 branch_context);
487 }
488 return node.AddVisit(
489 [](absl::Span<Literal> operands) -> StatusOr<Literal> {
490 int64_t pred_is_dynamic = operands[1].Get<bool>({});
491 if (pred_is_dynamic) {
492 // If predicate is dynamic, return the value of the first branch
493 // -- If all branches return the same value, this is the value
494 // that we want; If not, the value will be masked anyway so the
495 // value inside doesn't matter.
496 return std::move(operands[2]);
497 } else {
498 // If predicate is static, return the value of the given branch.
499 int64_t branch_index = 0;
500 if (operands[0].shape().element_type() == PRED) {
501 if (operands[0].Get<bool>({})) {
502 branch_index = 0;
503 } else {
504 branch_index = 1;
505 }
506 } else {
507 branch_index = operands[0].GetIntegralAsS64({}).value();
508 }
509 const int64_t branch_dynamism_index = 2 + branch_index;
510 return std::move(operands[branch_dynamism_index]);
511 }
512 });
513 }
514 case HloOpcode::kGetTupleElement: {
515 int64_t operand_handle = root->operand_ids(0);
516 PostorderDFSNode result;
517 context.shape_index.push_front(root->tuple_index());
518 return PostorderDFSNode()
519 .AddDependency(operand_handle, type, context)
520 .AddVisit([](Literal operand) { return operand; });
521 }
522 case HloOpcode::kReduce:
523 case HloOpcode::kScatter:
524 case HloOpcode::kReduceWindow: {
525 const HloComputationProto* computation_proto =
526 handle_to_computation(root->called_computation_ids(0));
527 return result.AddVisit(
528 [root, computation_proto,
529 context](absl::Span<Literal> operands) -> StatusOr<Literal> {
530 TF_ASSIGN_OR_RETURN(
531 auto computation,
532 HloComputation::CreateFromProto(*computation_proto, {}));
533 return HloProtoEvaluator(*root)
534 .WithOperands(operands)
535 .WithComputation(std::move(computation))
536 .WithSubshape(context.shape_index)
537 .Evaluate();
538 });
539 }
540 default: {
541 if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
542 // There could be many operands of a tuple, but we are only interested in
543 // one of them, represented by `tuple_operand_index`.
544 int64_t tuple_operand_index = context.shape_index.front();
545 InferenceContext tuple_operand_context = context;
546 tuple_operand_context.shape_index.pop_front();
547 return PostorderDFSNode()
548 .AddDependency(root->operand_ids(tuple_operand_index), type,
549 tuple_operand_context)
550 .AddVisit([](Literal operand) { return operand; });
551 }
552 return result.AddVisit([root](absl::Span<Literal> operands) {
553 return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
554 });
555 }
556 }
557 }
558
559 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeUpperBound(
560 int64_t handle, InferenceContext context) {
561 const HloInstructionProto* root = handle_to_instruction(handle);
562 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
563 switch (opcode) {
564 case HloOpcode::kGetDimensionSize: {
565 int64_t dimension = root->dimensions(0);
566 int64_t operand_handle = root->operand_ids(0);
567 const HloInstructionProto* operand_proto =
568 handle_to_instruction(operand_handle);
569 return PostorderDFSNode().AddVisit(
570 [operand_proto, dimension]() -> StatusOr<Literal> {
571 return LiteralUtil::CreateR0<int32>(
572 operand_proto->shape().dimensions(dimension));
573 });
574 }
575 case HloOpcode::kAbs: {
576 // upper-bound(abs(operand)) = max(abs(lower-bound(operand)),
577 // abs(upper-bound(operand)))
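// For example, if the operand is known to lie in [-5, 3], the upper
// bound of its absolute value is max(|-5|, |3|) = 5.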
578 return PostorderDFSNode()
579 .AddDependency(root->operand_ids(0),
580 PostorderDFSNodeType::kConstantLowerBound, context)
581 .AddDependency(root->operand_ids(0),
582 PostorderDFSNodeType::kConstantUpperBound, context)
583 .AddVisit([](Literal lower_bound,
584 Literal upper_bound) -> StatusOr<Literal> {
585 HloEvaluator evaluator;
586 TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
587 evaluator.EvaluateElementwiseUnaryOp(
588 HloOpcode::kAbs, lower_bound));
589 TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
590 evaluator.EvaluateElementwiseUnaryOp(
591 HloOpcode::kAbs, upper_bound));
592 return evaluator.EvaluateElementwiseBinaryOp(
593 HloOpcode::kMaximum, lower_bound_abs, upper_bound_abs);
594 });
595 }
596 case HloOpcode::kSort: {
597 auto dfs = PostorderDFSNode();
598 InferenceContext dep_context = context;
599 dep_context.shape_index = {};
600 if (!context.shape_index.empty()) {
601 // Lazy evaluation: Only need to evaluate a subelement in a
602 // variadic-sort tensor.
603 dfs.AddDependency(root->operand_ids(context.shape_index[0]),
604 PostorderDFSNodeType::kConstantUpperBound,
605 dep_context);
606 } else {
607 for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
608 dfs.AddDependency(root->operand_ids(i),
609 PostorderDFSNodeType::kConstantUpperBound,
610 dep_context);
611 }
612 }
613
614 return dfs.AddVisit(
615 [root, context](absl::Span<Literal> operands) -> StatusOr<Literal> {
616 std::vector<Literal> results;
617 results.reserve(operands.size());
618 // Conservatively set each element of the tensor to the max value.
619 for (int64_t i = 0; i < operands.size(); ++i) {
620 auto max = LiteralUtil::MaxElement(operands[i]);
621 results.emplace_back(
622 max.Broadcast(operands[i].shape(), {}).ValueOrDie());
623 }
624 if (ShapeUtil::GetSubshape(Shape(root->shape()),
625 context.shape_index)
626 .IsTuple()) {
627 return LiteralUtil::MakeTupleOwned(std::move(results));
628 } else {
629 return std::move(results[0]);
630 }
631 });
632 }
633 case HloOpcode::kNegate: {
634 // upper-bound(negate(operand)) = negate(lower-bound(operand))
635 return PostorderDFSNode()
636 .AddDependency(root->operand_ids(0),
637 PostorderDFSNodeType::kConstantLowerBound, context)
638 .AddVisit([](Literal lower_bound) -> StatusOr<Literal> {
639 HloEvaluator evaluator;
640 return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
641 lower_bound);
642 });
643 }
644 case HloOpcode::kSubtract:
645 case HloOpcode::kDivide: {
646 // Lower-bound is used for second operand of subtract and divide.
647 return PostorderDFSNode()
648 .AddDependency(root->operand_ids(0),
649 PostorderDFSNodeType::kConstantUpperBound, context)
650 .AddDependency(root->operand_ids(1),
651 PostorderDFSNodeType::kConstantLowerBound, context)
652 .AddVisit([root, opcode, this](
653 Literal upper_bound,
654 Literal lower_bound) -> StatusOr<Literal> {
655 if (opcode == HloOpcode::kDivide &&
656 this->IsValueEffectiveInteger(root->operand_ids(1))) {
657 // Because in many cases the lower bound of a value is
658 // integer 0, instead of throwing a divide-by-zero error
659 // at compile time we defer the check to runtime: where the
660 // divisor's lower bound is 0, we substitute 1 as a
661 // placeholder before evaluating the division.
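// For example (illustrative): if the divisor's lower-bound
// literal is {0, 2}, it is adjusted to {1, 2} before the
// division is evaluated.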
662 HloEvaluator evaluator;
663 auto zero = LiteralUtil::Zero(lower_bound.shape().element_type());
664 zero = zero.Broadcast(lower_bound.shape(), {}).ValueOrDie();
665 TF_ASSIGN_OR_RETURN(
666 auto lower_bound_is_zero,
667 evaluator.EvaluateElementwiseCompareOp(
668 ComparisonDirection::kEq, lower_bound, zero));
669
670 auto one = LiteralUtil::One(lower_bound.shape().element_type());
671 one = one.Broadcast(lower_bound.shape(), {}).ValueOrDie();
672 TF_ASSIGN_OR_RETURN(
673 lower_bound, evaluator.EvaluateElementwiseTernaryOp(
674 HloOpcode::kSelect, lower_bound_is_zero, one,
675 lower_bound));
676 }
677 std::vector<Literal> new_operands;
678 new_operands.emplace_back(std::move(upper_bound));
679 new_operands.emplace_back(std::move(lower_bound));
680 return HloProtoEvaluator(*root)
681 .WithOperands(absl::MakeSpan(new_operands))
682 .Evaluate();
683 });
684 }
685 case HloOpcode::kCustomCall: {
686 if (root->custom_call_target() == "SetBound") {
687 return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
688 if (root->literal().shape().element_type() == TUPLE) {
689 // First literal of SetBound contains bounds, second literal
690 // contains dynamism indicators.
691 return Literal::CreateFromProto(root->literal().tuple_literals(0));
692 } else {
693 return Literal::CreateFromProto(root->literal());
694 }
695 });
696 } else if (root->custom_call_target() == "Sharding") {
697 return PostorderDFSNode()
698 .AddDependency(root->operand_ids(0),
699 PostorderDFSNodeType::kConstantUpperBound, context)
700 .AddVisit([](Literal operand) { return operand; });
701 }
702 return InvalidArgument(
703 "Upper-bound inferencing on custom call %s is not supported",
704 root->DebugString());
705 }
706 default:
707 return AnalyzeConstantValueFallback(
708 handle, PostorderDFSNodeType::kConstantUpperBound, context);
709 }
710 }
711
712 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeLowerBound(
713 int64_t handle, InferenceContext context) {
714 const HloInstructionProto* root = handle_to_instruction(handle);
715 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
716 switch (opcode) {
717 case HloOpcode::kGetDimensionSize: {
718 int64_t dimension = root->dimensions(0);
719 int64_t operand_handle = root->operand_ids(0);
720 const HloInstructionProto* operand_proto =
721 handle_to_instruction(operand_handle);
722 return PostorderDFSNode().AddVisit(
723 [dimension, operand_proto]() -> StatusOr<Literal> {
724 if (operand_proto->shape().is_dynamic_dimension(dimension)) {
725 return LiteralUtil::CreateR0<int32>(0);
726 } else {
727 return LiteralUtil::CreateR0<int32>(
728 operand_proto->shape().dimensions(dimension));
729 }
730 });
731 }
732 case HloOpcode::kAbs: {
733 // lower-bound(abs(operand)) = min(abs(lower-bound(operand)),
734 // abs(upper-bound(operand)))
735 return PostorderDFSNode()
736 .AddDependency(root->operand_ids(0),
737 PostorderDFSNodeType::kConstantLowerBound, context)
738 .AddDependency(root->operand_ids(0),
739 PostorderDFSNodeType::kConstantUpperBound, context)
740 .AddVisit([](Literal lower_bound,
741 Literal upper_bound) -> StatusOr<Literal> {
742 HloEvaluator evaluator;
743 TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
744 evaluator.EvaluateElementwiseUnaryOp(
745 HloOpcode::kAbs, lower_bound));
746 TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
747 evaluator.EvaluateElementwiseUnaryOp(
748 HloOpcode::kAbs, upper_bound));
749 return evaluator.EvaluateElementwiseBinaryOp(
750 HloOpcode::kMinimum, lower_bound_abs, upper_bound_abs);
751 });
752 }
753 case HloOpcode::kNegate: {
754 // lower-bound(negate(operand)) = negate(upper-bound(operand))
755 return PostorderDFSNode()
756 .AddDependency(root->operand_ids(0),
757 PostorderDFSNodeType::kConstantUpperBound, context)
758 .AddVisit([](Literal upper_bound) -> StatusOr<Literal> {
759 HloEvaluator evaluator;
760 return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
761 upper_bound);
762 });
763 }
764 case HloOpcode::kSubtract:
765 case HloOpcode::kDivide: {
766 // Upper bound is used for second operand of subtract and divide.
767 return PostorderDFSNode()
768 .AddDependency(root->operand_ids(0),
769 PostorderDFSNodeType::kConstantLowerBound, context)
770 .AddDependency(root->operand_ids(1),
771 PostorderDFSNodeType::kConstantUpperBound, context)
772 .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
773 return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
774 });
775 }
776 default:
777 return AnalyzeConstantValueFallback(
778 handle, PostorderDFSNodeType::kConstantLowerBound, context);
779 }
780 }
781
782 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstant(
783 int64_t handle, InferenceContext context) {
784 const HloInstructionProto* root = handle_to_instruction(handle);
785 HloOpcode opcode = StringToHloOpcode(root->opcode()).ValueOrDie();
786 switch (opcode) {
787 case HloOpcode::kGetDimensionSize: {
788 int64_t dimension = root->dimensions(0);
789 int64_t operand_handle = root->operand_ids(0);
790 const HloInstructionProto* operand_proto =
791 handle_to_instruction(operand_handle);
792 return PostorderDFSNode().AddVisit(
793 [operand_proto, dimension, root]() -> StatusOr<Literal> {
794 if (operand_proto->shape().is_dynamic_dimension(dimension)) {
795 // The value is dynamic; we return garbage data here and mask it
796 // out later.
797 return CreateGarbageLiteral(Shape(root->shape()));
798 } else {
799 return LiteralUtil::CreateR0<int32>(
800 operand_proto->shape().dimensions(dimension));
801 }
802 });
803 }
804 case HloOpcode::kSubtract:
805 case HloOpcode::kCos:
806 case HloOpcode::kSin:
807 case HloOpcode::kNegate:
808 case HloOpcode::kAbs:
809 case HloOpcode::kDivide: {
810 PostorderDFSNode result;
811 for (auto operand_id : root->operand_ids()) {
812 result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
813 context);
814 }
815 return result.AddVisit(
816 [root](absl::Span<Literal> operands) -> StatusOr<Literal> {
817 return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
818 });
819 }
820 case HloOpcode::kCustomCall: {
821 if (root->custom_call_target() == "SetBound") {
822 // `SetBound` doesn't change the static value of a tensor, so forward
823 // the operand when analyzing static value.
824 return PostorderDFSNode()
825 .AddDependency(root->operand_ids(0),
826 PostorderDFSNodeType::kConstantValue, context)
827 .AddVisit(
828 [](Literal operand) -> StatusOr<Literal> { return operand; });
829 } else if (root->custom_call_target() == "Sharding") {
830 return PostorderDFSNode()
831 .AddDependency(root->operand_ids(0),
832 PostorderDFSNodeType::kConstantValue, context)
833 .AddVisit([](Literal operand) { return operand; });
834 } else {
835 return PostorderDFSNode().AddVisit(
836 [root, context](absl::Span<Literal>) {
837 // The value is dynamic. We return a garbage literal here, which
838 // will be masked out later.
839 return CreateGarbageLiteral(ShapeUtil::GetSubshape(
840 Shape(root->shape()), context.shape_index));
841 });
842 }
843 }
844 case HloOpcode::kSort: {
845 PostorderDFSNode result;
846 InferenceContext dep_context = context;
847 dep_context.shape_index = {};
848 for (auto operand_id : root->operand_ids()) {
849 result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
850 dep_context);
851 }
852 const HloComputationProto* computation_proto =
853 handle_to_computation(root->called_computation_ids(0));
854 return result.AddVisit(
855 [root, context, computation_proto](
856 absl::Span<Literal> operands) -> StatusOr<Literal> {
857 TF_ASSIGN_OR_RETURN(
858 auto computation,
859 HloComputation::CreateFromProto(*computation_proto, {}));
860 return HloProtoEvaluator(*root)
861 .WithOperands(operands)
862 .WithComputation(std::move(computation))
863 .WithSubshape(context.shape_index)
864 .Evaluate();
865 });
866 }
867 default:
868 return AnalyzeConstantValueFallback(
869 handle, PostorderDFSNodeType::kConstantValue, context);
870 }
871 }
872
873 StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeIsDynamic(
874 int64_t handle, PostorderDFSNodeType type, InferenceContext context) {
875 const HloInstructionProto* root = handle_to_instruction(handle);
876 // Invariant check.
877 TF_RET_CHECK(root);
878 VLOG(1) << "Analyzing IsDynamic on " << root->DebugString();
879 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
880 PostorderDFSNode result;
881 for (auto operand_id : root->operand_ids()) {
882 InferenceContext dep_context = context;
883 dep_context.shape_index = {};
884 result.AddDependency(operand_id, type, dep_context);
885 }
886 switch (opcode) {
887 case HloOpcode::kGetDimensionSize: {
888 int64_t dimension = root->dimensions(0);
889 int64_t operand_handle = root->operand_ids(0);
890 const HloInstructionProto* operand_proto =
891 handle_to_instruction(operand_handle);
892 return PostorderDFSNode().AddVisit(
893 [operand_proto, dimension, type]() -> StatusOr<Literal> {
894 if (type == PostorderDFSNodeType::kBoundIsDynamic) {
895 // The bound of a dynamic dimension is not dynamic.
896 return LiteralUtil::CreateR0<bool>(false);
897 }
898 // The value of a dynamic dimension is dynamic.
899 return LiteralUtil::CreateR0<bool>(
900 operand_proto->shape().is_dynamic_dimension(dimension));
901 });
902 }
903 case HloOpcode::kSort: {
904 auto dfs = PostorderDFSNode();
905 InferenceContext dep_context = context;
906 dep_context.shape_index = {};
907
908 for (int64_t i = 0; i < root->operand_ids_size(); ++i) {
909 dfs.AddDependency(root->operand_ids(i), type, dep_context);
910 }
911
912 return dfs.AddVisit([root, context, type](absl::Span<Literal> operands)
913 -> StatusOr<Literal> {
914 bool all_operands_values_static = true;
915 for (int64_t i = 0; i < operands.size(); ++i) {
916 all_operands_values_static &= operands[i].IsAll(0);
917 }
918 if (type == PostorderDFSNodeType::kValueIsDynamic) {
919 // If even a single operand of a sort is dynamic, we
920 // conservatively say all results are dynamic.
921 return CreatePredLiteral(!all_operands_values_static,
922 ShapeUtil::GetSubshape(Shape(root->shape()),
923 context.shape_index));
924 }
925 CHECK(type == PostorderDFSNodeType::kBoundIsDynamic);
926 // The condition for bounds is more relaxed than for values. If we know the
927 // bounds of each element [B0, B1... Bn], all results have the same
928 // bound
929 // [max(B0, B1...), max(B0, B1...), ...]
930 if (!context.shape_index.empty()) {
931 int64_t index = context.shape_index[0];
932 bool all_values_static = operands[index].IsAll(0);
933 return CreatePredLiteral(!all_values_static, operands[index].shape());
934 }
935
936 std::vector<Literal> results;
937 results.reserve(operands.size());
938 for (int64_t i = 0; i < operands.size(); ++i) {
939 bool all_values_static = operands[i].IsAll(0);
940 results.emplace_back(
941 CreatePredLiteral(!all_values_static, operands[i].shape()));
942 }
943 if (!ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index)
944 .IsTuple()) {
945 return std::move(results[0]);
946 }
947 return LiteralUtil::MakeTupleOwned(std::move(results));
948 });
949 }
950 case HloOpcode::kSetDimensionSize:
951 return result.AddVisit([root, type](absl::Span<Literal> operands) {
952 bool any_dynamic_operand = absl::c_any_of(
953 operands, [](Literal& operand) { return !operand.IsAll(0); });
954 // If values in a tensor `t` with bound are [e0, e1, e2...], we can say
955 // the max value of each position is [max(t), max(t), max(t), ...]. The
956 // effective size of this tensor doesn't change the max value.
957 return CreatePredLiteral(
958 type == PostorderDFSNodeType::kValueIsDynamic &&
959 any_dynamic_operand,
960 ShapeUtil::MakeStaticShape(Shape(root->shape())));
961 });
962 case HloOpcode::kDynamicSlice: {
963 return result.AddVisit([root](absl::Span<Literal> operands) {
964 // If any of the operands is dynamic, we say the output is dynamic.
965 bool any_dynamic_operand = absl::c_any_of(
966 operands, [](Literal& operand) { return !operand.IsAll(0); });
967 return CreatePredLiteral(any_dynamic_operand, Shape(root->shape()));
968 });
969 }
970 case HloOpcode::kAbs:
971 case HloOpcode::kRoundNearestAfz:
972 case HloOpcode::kBitcast:
973 case HloOpcode::kCeil:
974 case HloOpcode::kCollectivePermuteDone:
975 case HloOpcode::kCos:
976 case HloOpcode::kClz:
977 case HloOpcode::kExp:
978 case HloOpcode::kExpm1:
979 case HloOpcode::kFloor:
980 case HloOpcode::kImag:
981 case HloOpcode::kIsFinite:
982 case HloOpcode::kLog:
983 case HloOpcode::kLog1p:
984 case HloOpcode::kNot:
985 case HloOpcode::kNegate:
986 case HloOpcode::kPopulationCount:
987 case HloOpcode::kReal:
988 case HloOpcode::kRsqrt:
989 case HloOpcode::kLogistic:
990 case HloOpcode::kSign:
991 case HloOpcode::kSin:
992 case HloOpcode::kConvert:
993 case HloOpcode::kSqrt:
994 case HloOpcode::kCbrt:
995 case HloOpcode::kTanh: {
996 // Forward the operand, as these ops don't change whether a value is dynamic or static.
997 return result.AddVisit([](Literal operand) { return operand; });
998 }
999 case HloOpcode::kAdd:
1000 case HloOpcode::kAtan2:
1001 case HloOpcode::kDivide:
1002 case HloOpcode::kComplex:
1003 case HloOpcode::kMaximum:
1004 case HloOpcode::kMinimum:
1005 case HloOpcode::kMultiply:
1006 case HloOpcode::kPower:
1007 case HloOpcode::kRemainder:
1008 case HloOpcode::kSubtract:
1009 case HloOpcode::kCompare:
1010 case HloOpcode::kAnd:
1011 case HloOpcode::kOr:
1012 case HloOpcode::kXor:
1013 case HloOpcode::kShiftLeft:
1014 case HloOpcode::kShiftRightArithmetic:
1015 case HloOpcode::kShiftRightLogical: {
1016 return result.AddVisit([root](absl::Span<Literal> operands) {
1017 return HloProtoEvaluator(*root)
1018 .WithOperands(operands)
1019 .WithPrimitiveType(PRED)
1020 .WithOpCode(HloOpcode::kOr)
1021 .Evaluate();
1022 });
1023 }
1024 case HloOpcode::kTuple:
1025 case HloOpcode::kTranspose:
1026 case HloOpcode::kSlice:
1027 case HloOpcode::kBroadcast:
1028 case HloOpcode::kReverse:
1029 case HloOpcode::kConcatenate:
1030 case HloOpcode::kReshape:
1031 case HloOpcode::kPad: {
1032 if (opcode == HloOpcode::kTuple && !context.shape_index.empty()) {
1033 // There could be many operands of a tuple, but we are only interested in
1034 // one of them, represented by `tuple_operand_index`.
1035 int64_t tuple_operand_index = context.shape_index.front();
1036 InferenceContext tuple_operand_context = context;
1037 tuple_operand_context.shape_index.pop_front();
1038 return PostorderDFSNode()
1039 .AddDependency(root->operand_ids(tuple_operand_index), type,
1040 tuple_operand_context)
1041 .AddVisit([](Literal operand) { return operand; });
1042 }
1043 return result.AddVisit([root](absl::Span<Literal> operands) {
1044 return HloProtoEvaluator(*root)
1045 .WithOperands(operands)
1046 .WithPrimitiveType(PRED)
1047 .Evaluate();
1048 });
1049 }
1050 case HloOpcode::kConditional: {
1051 auto node = PostorderDFSNode();
1052 auto* conditional_proto = root;
1053 InferenceContext predicate_context = context;
1054 predicate_context.shape_index = {};
1055 // Add dependencies to analyze the predicate of the conditional.
1056 node.AddDependency(conditional_proto->operand_ids(0),
1057 PostorderDFSNodeType::kConstantValue,
1058 predicate_context)
1059 .AddDependency(conditional_proto->operand_ids(0),
1060 PostorderDFSNodeType::kValueIsDynamic,
1061 predicate_context);
1062 const int64_t branch_size =
1063 conditional_proto->called_computation_ids_size();
1064 for (int64_t i = 0; i < branch_size; ++i) {
1065 int64_t branch_root =
1066 handle_to_computation(conditional_proto->called_computation_ids(i))
1067 ->root_id();
1068 InferenceContext branch_context = context;
1069 branch_context.caller_operand_handles.push_back(
1070 conditional_proto->operand_ids(i + 1));
1071 node.AddDependency(branch_root, PostorderDFSNodeType::kConstantValue,
1072 branch_context,
1073 absl::StrFormat("branch %lld's value", i))
1074 .AddDependency(branch_root, PostorderDFSNodeType::kValueIsDynamic,
1075 branch_context,
1076 absl::StrFormat("branch %lld's dynamism", i));
1077 }
1078 // Predicate uses 2 dependencies:
1079 // 0: Predicate value.
1080 // 1: Predicate is dynamic.
1081 // Each branch i has 2 dependencies:
1082 // 2*i: Branch result value
1083 // 2*i + 1: Branch value is dynamic.
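// For example, with two branches the resolved operand span is laid out as
//   {pred_value, pred_is_dynamic,
//    branch0_value, branch0_is_dynamic,
//    branch1_value, branch1_is_dynamic}.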
1084 return node.AddVisit(
1085 [root, branch_size,
1086 context](absl::Span<Literal> operands) -> StatusOr<Literal> {
1087 int64_t pred_is_dynamic = operands[1].Get<bool>({});
1088 auto result = CreatePredLiteral(
1089 true, ShapeUtil::GetSubshape(Shape(root->shape()),
1090 context.shape_index));
1091 if (pred_is_dynamic) {
1092 VLOG(1) << "predict is dynamic value" << result.ToString();
1093 // If predicate is dynamic, the result is only static if all
1094 // branches are static and return the same value.
1095 result.MutableEachCell<bool>([&](absl::Span<const int64> indices,
1096 bool value) {
1097 string branch_value = operands[2].GetAsString(indices, {});
1098 for (int64_t i = 0; i < branch_size; ++i) {
1099 const int64_t branch_value_index = 2 + 2 * i;
1100 const int64_t branch_dynamism_index = 2 + 2 * i + 1;
1101 auto branch_is_dynamic =
1102 operands[branch_dynamism_index].Get<bool>(indices);
1103 if (branch_is_dynamic) {
1104 return true;
1105 }
1106
1107 if (branch_value !=
1108 operands[branch_value_index].GetAsString(indices, {})) {
1109 return true;
1110 }
1111 }
1112 // Value of the branch is static.
1113 return false;
1114 });
1115 return result;
1116 } else {
1117 VLOG(1) << "predict is constant value";
1118 // If predicate is static, return true if given branch result
1119 // value is dynamic.
1120 int64_t branch_index = 0;
1121 if (operands[0].shape().element_type() == PRED) {
1122 if (operands[0].Get<bool>({})) {
1123 branch_index = 0;
1124 } else {
1125 branch_index = 1;
1126 }
1127 } else {
1128 branch_index = operands[0].GetIntegralAsS64({}).value();
1129 }
1130 const int64_t branch_dynamism_index = 2 + 2 * branch_index + 1;
1131 return std::move(operands[branch_dynamism_index]);
1132 }
1133 });
1134 }
1135 case HloOpcode::kGetTupleElement: {
1136 int64_t operand_handle = root->operand_ids(0);
1137 PostorderDFSNode result;
1138 context.shape_index.push_front(root->tuple_index());
1139 return PostorderDFSNode()
1140 .AddDependency(operand_handle, type, context)
1141 .AddVisit([](Literal operand) { return operand; });
1142 }
1143
1144 case HloOpcode::kReduce: {
1145 return result.AddVisit([root, context](absl::Span<Literal> operands) {
1146 Shape root_shape = Shape(root->shape());
1147 Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
1148 std::unique_ptr<HloComputation> reduce_or;
1149 if (root_shape.IsTuple()) {
1150 // Variadic reduce.
1151 HloComputation::Builder b("reduce_or");
1152 std::vector<HloInstruction*> results;
1153 results.reserve(root_shape.tuple_shapes_size());
1154 for (int64_t i = 0; i < root_shape.tuple_shapes_size(); ++i) {
1155 auto lhs = b.AddInstruction(
1156 HloInstruction::CreateParameter(i, scalar_shape, "lhs"));
1157 auto rhs = b.AddInstruction(HloInstruction::CreateParameter(
1158 i + root_shape.tuple_shapes_size(), scalar_shape, "rhs"));
1159 results.push_back(b.AddInstruction(HloInstruction::CreateBinary(
1160 scalar_shape, HloOpcode::kOr, lhs, rhs)));
1161 }
1162 b.AddInstruction(HloInstruction::CreateTuple(results));
1163 reduce_or = b.Build();
1164 } else {
1165 HloComputation::Builder b("reduce_or");
1166 auto lhs = b.AddInstruction(
1167 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1168 auto rhs = b.AddInstruction(
1169 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1170 b.AddInstruction(HloInstruction::CreateBinary(
1171 scalar_shape, HloOpcode::kOr, lhs, rhs));
1172 reduce_or = b.Build();
1173 }
1174
1175 return HloProtoEvaluator(*root)
1176 .WithOperands(operands)
1177 .WithPrimitiveType(PRED)
1178 .WithComputation(std::move(reduce_or))
1179 // Reduce could produce tuple shape, only fetch what we need.
1180 .WithSubshape(context.shape_index)
1181 .Evaluate();
1182 });
1183 }
1184 case HloOpcode::kConstant:
1185 case HloOpcode::kIota: {
1186 return result.AddVisit(
1187 [root]() { return CreatePredLiteral(false, Shape(root->shape())); });
1188 }
1189 case HloOpcode::kParameter: {
1190 if (opcode == HloOpcode::kParameter &&
1191 !context.caller_operand_handles.empty()) {
1192 int64_t caller_operand = context.caller_operand_handles.back();
1193 context.caller_operand_handles.pop_back();
1194 return result.AddDependency(caller_operand, type, context)
1195 .AddVisit([](Literal literal) { return literal; });
1196 }
1197 return result.AddVisit([root, context]() {
1198 return CreatePredLiteral(
1199 true,
1200 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1201 });
1202 }
1203 case HloOpcode::kSelect: {
1204 return PostorderDFSNode()
1205 .AddDependency(root->operand_ids(0),
1206 PostorderDFSNodeType::kConstantValue, context)
1207 .AddDependency(root->operand_ids(0),
1208 PostorderDFSNodeType::kValueIsDynamic, context)
1209 // lhs dependency.
1210 .AddDependency(root->operand_ids(1), type, context)
1211 // rhs dependency.
1212 .AddDependency(root->operand_ids(2), type, context)
1213 .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
1214 OptionalLiteral optional_selector_literal(std::move(operands[0]),
1215 std::move(operands[1]));
1216 Literal lhs = std::move(operands[2]);
1217 Literal rhs = std::move(operands[3]);
1218 auto result = CreatePredLiteral(true, Shape(root->shape()));
1219 result.MutableEachCell<bool>(
1220 [&](absl::Span<const int64> indices, bool value) {
1221 absl::optional<bool> optional_selector =
1222 optional_selector_literal.Get<bool>(indices);
1223
1224 bool lhs_value = lhs.Get<bool>(indices);
1225 bool rhs_value = rhs.Get<bool>(indices);
1226 if (optional_selector.has_value()) {
1227 // Manually evaluate the selection without using Evaluator.
1228 if (*optional_selector) {
1229 return lhs_value;
1230 } else {
1231 return rhs_value;
1232 }
1233 } else {
1234 // Conservatively assume value is dynamic if selector is
1235 // dynamic.
1236 return true;
1237 }
1238 });
1239 return result;
1240 });
1241 }
1242 case HloOpcode::kGather: {
1243 return PostorderDFSNode()
1244 .AddDependency(root->operand_ids(0), type, context)
1245 .AddDependency(root->operand_ids(1),
1246 PostorderDFSNodeType::kConstantValue, context)
1247 .AddDependency(root->operand_ids(1),
1248 PostorderDFSNodeType::kValueIsDynamic, context)
1249 .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
1250 OptionalLiteral optional_selector_literal(std::move(operands[1]),
1251 std::move(operands[2]));
1252
1253 if (!optional_selector_literal.AllValid()) {
1254 // Conservatively assume results are dynamic.
1255 return CreatePredLiteral(true, Shape(root->shape()));
1256 }
1257 std::vector<Literal> new_operands;
1258 new_operands.emplace_back(std::move(operands[0]));
1259 new_operands.emplace_back(
1260 optional_selector_literal.GetValue()->Clone());
1261
1262 return HloProtoEvaluator(*root)
1263 .WithOperands(absl::MakeSpan(new_operands))
1264 .WithPrimitiveType(PRED)
1265 .Evaluate();
1266 });
1267 }
1268 case HloOpcode::kCustomCall: {
1269 if (root->custom_call_target() == "SetBound") {
1270 return PostorderDFSNode().AddVisit([type, root]() -> StatusOr<Literal> {
1271 if (type == PostorderDFSNodeType::kBoundIsDynamic) {
1272 return CreatePredLiteral(false, Shape(root->shape()));
1273 } else {
1274 if (root->literal().shape().element_type() == TUPLE) {
1275 // First literal of SetBound contains bounds, second literal
1276 // contains dynamism indicators.
1277 return Literal::CreateFromProto(
1278 root->literal().tuple_literals(1));
1279 } else {
1280 return Literal::CreateFromProto(root->literal());
1281 }
1282 }
1283 });
1284 } else if (root->custom_call_target() == "Sharding") {
1285 return result.AddVisit([](Literal operand) { return operand; });
1286 } else {
1287 return InvalidArgument(
1288 "Dynamic inferencing on custom call %s is not supported",
1289 root->DebugString());
1290 }
1291
1292 break;
1293 }
1294
1295 case HloOpcode::kCall:
1296 case HloOpcode::kRecv:
1297 case HloOpcode::kRecvDone:
1298 case HloOpcode::kSend:
1299 case HloOpcode::kSendDone:
1300 case HloOpcode::kWhile: {
1301 return PostorderDFSNode().AddVisit([root,
1302 context]() -> StatusOr<Literal> {
1303 return CreatePredLiteral(
1304 true,
1305 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1306 });
1307 break;
1308 }
1309 default:
1310 return PostorderDFSNode().AddVisit([root,
1311 context]() -> StatusOr<Literal> {
1312 return CreatePredLiteral(
1313 true,
1314 ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
1315 });
1316 }
1317 }
1318
1319 StatusOr<Literal> PostorderDFSVisitor::PostOrderDFSVisit(
1320 int64_t handle, PostorderDFSNodeType type) {
1321 enum VisitState {
1322 kUnvisited = 0,
1323 kVisiting,
1324 kVisited,
1325 };
1326
1327 int64_t unique_id = 0;
1328 struct WorkItem {
1329 explicit WorkItem(int64_t handle, InferenceContext context,
1330 PostorderDFSNodeType type, VisitState state, int64_t id)
1331 : handle(handle),
1332 context(std::move(context)),
1333 type(type),
1334 state(state),
1335 id(id) {}
1336 int64 handle; // Handle of the node in the graph.
1337 InferenceContext context;
1338 PostorderDFSNodeType type;
1339 VisitState state;
1340 Visit visit; // The handler to call once the dependencies are resolved into
1341 // literal form.
1342 int64 id; // Unique id in the work queue, starting from 0.
1343 std::vector<CacheKey> dependencies;
1344
1345 CacheKey GetCacheKey() { return CacheKey(handle, context, type); }
1346 };
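// Illustrative walk-through (descriptive note, not from the original source):
// visiting `add(a, b)` for kConstantValue first pushes the add as an unvisited
// WorkItem; when it is first examined it is marked kVisiting and dependency
// items for `a` and `b` are pushed on top of it. Once those are evaluated and
// cached, the add item is revisited and its stored visit function is called
// with the two operand literals.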
1347
1348 std::vector<WorkItem> stack;
1349 WorkItem root(handle, InferenceContext({}, {}), type, kUnvisited,
1350 unique_id++);
1351 stack.push_back(root);
1352 while (!stack.empty()) {
1353 WorkItem& item = stack.back();
1354 VLOG(1) << "stack top shape index: " << item.context.shape_index.ToString();
1355 VLOG(1) << "stack top "
1356 << handle_to_instruction(item.handle)->DebugString();
1357 if (item.state == kVisiting) {
1358 VLOG(1) << "visiting";
1359 // The dependencies are ready, visit the node itself.
1360
1361 // Gather dependencies and transform them into literals.
1362 std::vector<Literal> literals;
1363 for (CacheKey& dep_key : item.dependencies) {
1364 TF_RET_CHECK(evaluated.contains(dep_key));
1365 literals.emplace_back(evaluated.at(dep_key).Clone());
1366 }
1367 VLOG(1) << "Start visiting with dependency type: "
1368 << PostorderDFSNodeTypeToString(item.type);
1369 TF_ASSIGN_OR_RETURN(auto literal, item.visit(absl::MakeSpan(literals)));
1370 VLOG(1) << "End visiting: " << literal.ToString();
1371 evaluated[item.GetCacheKey()] = std::move(literal);
1372 stack.pop_back();
1373 continue;
1374 }
1375 // This is the first time we see this node; we want to gather its
1376 // dependencies.
1377 VLOG(1) << "unvisited";
1378 if (evaluated.contains(item.GetCacheKey())) {
1379 stack.pop_back();
1380 continue;
1381 }
1382 item.state = kVisiting;
1383 PostorderDFSNode node;
1384 switch (item.type) {
1385 case PostorderDFSNodeType::kConstantValue: {
1386 VLOG(1) << "constant value";
1387 TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle, item.context));
1388 break;
1389 }
1390 case PostorderDFSNodeType::kConstantLowerBound: {
1391 VLOG(1) << "constant lower bound";
1392 TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle, item.context));
1393 break;
1394 }
1395 case PostorderDFSNodeType::kConstantUpperBound: {
1396 VLOG(1) << "constant upper bound";
1397 TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle, item.context));
1398 break;
1399 }
1400 case PostorderDFSNodeType::kBoundIsDynamic:
1401 case PostorderDFSNodeType::kValueIsDynamic: {
1402 VLOG(1) << "value is dynamic";
1403 TF_ASSIGN_OR_RETURN(
1404 node, AnalyzeIsDynamic(item.handle, item.type, item.context));
1405 break;
1406 }
1407 }
1408 // Store the visit function which is needed when its dependencies are
1409 // resolved.
1410 item.visit = node.visit;
1411
1412 const int64 current_item_id = stack.size() - 1;
1413 // Enqueue dependencies into the stack. `item` shouldn't be accessed after
1414 // this point.
1415 for (const PostorderDFSDep& dep : node.dependencies) {
1416 VLOG(1) << "dep " << dep.annotation
1417 << "::" << handle_to_instruction(dep.handle)->DebugString()
1418 << "index" << dep.context.shape_index
1419 << " stack size:" << stack.size();
1420 stack.emplace_back(dep.handle, dep.context, dep.type, kUnvisited,
1421 unique_id++);
1422 stack[current_item_id].dependencies.push_back(stack.back().GetCacheKey());
1423 }
1424 }
1425 VLOG(1) << "done" << evaluated[root.GetCacheKey()].ToString();
1426 return evaluated[root.GetCacheKey()].Clone();
1427 }
1428
1429 StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
1430 PostorderDFSVisitor visitor(
1431 [&](int64_t handle) {
1432 return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1433 },
1434 [&](int64_t handle) { return &(builder_->embedded_[handle]); });
1435 auto result = visitor.PostOrderDFSVisit(
1436 op.handle(), PostorderDFSNodeType::kValueIsDynamic);
1437 return result;
1438 }
1439
1440 absl::optional<int64> ValueInference::CseOpHandle(int64_t handle) {
1441 auto inst = builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1442 HloOpcode opcode = StringToHloOpcode(inst->opcode()).ValueOrDie();
1443 // For now, only handle kGetDimensionSize as that's the most duplicated one.
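// For example (illustrative), two separate get-dimension-size ops that query
// the same dimension of the same operand hash to the same key, so the later
// handle can be replaced by the earlier, equivalent one.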
1444 if (opcode != HloOpcode::kGetDimensionSize) {
1445 return absl::nullopt;
1446 }
1447 int64_t hash = inst->operand_ids(0);
1448 hash =
1449 tensorflow::Hash64Combine(hash, std::hash<int64>()(inst->dimensions(0)));
1450 auto lookup = cse_map_.find(hash);
1451 if (lookup == cse_map_.end()) {
1452 cse_map_[hash] = handle;
1453 return absl::nullopt;
1454 }
1455 auto equivalent_op =
1456 builder_->LookUpInstructionByHandle(lookup->second).ValueOrDie();
1457 // Check that the op is indeed equivalent, to guard against hash collisions --
1458 // relatively easy to happen with a 64-bit hash.
1459 if (equivalent_op->opcode() != inst->opcode() ||
1460 equivalent_op->operand_ids(0) != inst->operand_ids(0) ||
1461 equivalent_op->dimensions(0) != inst->dimensions(0)) {
1462 // Hash collision, don't CSE.
1463 return absl::nullopt;
1464 }
1465 int64_t cse = lookup->second;
1466 if (handle != cse) {
1467 // Successfully found a handle that's not the same as input but equivalent.
1468 return cse;
1469 }
1470 return absl::nullopt;
1471 }
1472
1473 StatusOr<Literal> ValueInference::SimplifyOp(int64_t handle) {
1474 if (auto cse_handle = CseOpHandle(handle)) {
1475 // Use the CSE'd handle instead.
1476 return SimplifyOp(*cse_handle);
1477 }
1478 auto inst = builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1479 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode()));
1480 std::vector<Literal> operands;
1481 auto output_shape = Shape(inst->shape());
1482 switch (opcode) {
1483 case HloOpcode::kSlice:
1484 case HloOpcode::kConcatenate:
1485 case HloOpcode::kReshape:
1486 case HloOpcode::kBroadcast: {
1487 for (auto operand_id : inst->operand_ids()) {
1488 TF_ASSIGN_OR_RETURN(auto literal, SimplifyOp(operand_id));
1489 operands.emplace_back(std::move(literal));
1490 }
1491 // We put handles into the tensor and evaluate the results into a literal.
1492 // The literal also contains handles for each element position.
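// For example (illustrative), reshaping a tensor of handles [h0, h1, h2, h3]
// only rearranges the handles, so each output position still records which
// original scalar op produces its value.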
1493 return HloProtoEvaluator(*inst)
1494 .WithOperands(absl::MakeSpan(operands))
1495 .WithPrimitiveType(S64)
1496 .Evaluate();
1497 }
1498 case HloOpcode::kConvert: {
1499 // Only identity kConvert can be optimized away.
1500 auto operand = builder_->LookUpInstructionByHandle(inst->operand_ids(0))
1501 .ValueOrDie();
1502 if (Shape::Equal()(output_shape, Shape(operand->shape()))) {
1503 // Forward operand handle as result.
1504 return SimplifyOp(inst->operand_ids(0));
1505 } else {
1506 return CreateS64Literal(-1, output_shape);
1507 }
1508 }
1509 case HloOpcode::kAdd: {
1510 // a + (b - a) => b
1511 // a + b + (c - a) => b + c
1512 if (output_shape.rank() == 0) {
1513 TF_ASSIGN_OR_RETURN(auto lhs, SimplifyOp(inst->operand_ids(0)));
1514 TF_ASSIGN_OR_RETURN(auto rhs, SimplifyOp(inst->operand_ids(1)));
1515 int64_t lhs_handle = lhs.Get<int64>({});
1516 int64_t rhs_handle = rhs.Get<int64>({});
1517 if (lhs_handle == -1 || rhs_handle == -1) {
1518 return CreateS64Literal(-1, output_shape);
1519 }
1520 // Recursive lambda needs explicit signature.
1521 std::function<absl::optional<int64>(int64_t, int64_t)> can_be_optimized;
1522 can_be_optimized = [this, &can_be_optimized](
1523 int64_t lhs,
1524 int64_t rhs) -> absl::optional<int64> {
1525 auto rhs_inst = builder_->LookUpInstructionByHandle(rhs).ValueOrDie();
1526 HloOpcode rhs_opcode =
1527 StringToHloOpcode(rhs_inst->opcode()).ValueOrDie();
1528 if (rhs_opcode == HloOpcode::kSubtract) {
1529 auto sub_lhs_handle = SimplifyOp(rhs_inst->operand_ids(0))
1530 .ValueOrDie()
1531 .Get<int64>({});
1532 auto sub_rhs_handle = SimplifyOp(rhs_inst->operand_ids(1))
1533 .ValueOrDie()
1534 .Get<int64>({});
1535 if (sub_rhs_handle == lhs) {
1536 // lhs + (sub_lhs - sub_rhs) = sub_lhs if lhs == sub_rhs
1537 return sub_lhs_handle;
1538 }
1539 }
1540
1541 // Check the case for a + b + (c - a) => b + c
1542 auto lhs_inst = builder_->LookUpInstructionByHandle(lhs).ValueOrDie();
1543 HloOpcode lhs_opcode =
1544 StringToHloOpcode(lhs_inst->opcode()).ValueOrDie();
1545 if (lhs_opcode == HloOpcode::kAdd) {
1546 auto add_lhs_handle = SimplifyOp(lhs_inst->operand_ids(0))
1547 .ValueOrDie()
1548 .Get<int64>({});
1549 auto add_rhs_handle = SimplifyOp(lhs_inst->operand_ids(1))
1550 .ValueOrDie()
1551 .Get<int64>({});
1552 if (auto optimized = can_be_optimized(add_lhs_handle, rhs)) {
1553 return Add(XlaOp(add_rhs_handle, builder_),
1554 XlaOp(optimized.value(), builder_))
1555 .handle();
1556 }
1557 if (auto optimized = can_be_optimized(add_rhs_handle, rhs)) {
1558 return Add(XlaOp(add_lhs_handle, builder_),
1559 XlaOp(optimized.value(), builder_))
1560 .handle();
1561 }
1562 }
1563 return absl::nullopt;
1564 };
1565 if (auto optimized = can_be_optimized(lhs_handle, rhs_handle)) {
1566 return LiteralUtil::CreateR0<int64>(optimized.value());
1567 }
1568 // Swap lhs and rhs.
1569 if (auto optimized = can_be_optimized(rhs_handle, lhs_handle)) {
1570 return LiteralUtil::CreateR0<int64>(optimized.value());
1571 }
1572 // This sum can't be optimized, return sum of lhs and rhs. Note that we
1573 // can't just return the original sum as its lhs and rhs could be
1574 // optimized and different.
1575 XlaOp new_sum =
1576 Add(XlaOp(lhs_handle, builder_), XlaOp(rhs_handle, builder_));
1577
1578 return LiteralUtil::CreateR0<int64>(new_sum.handle());
1579 } else {
1580 return CreateS64Literal(-1, output_shape);
1581 }
1582 }
1583 default: {
1584 if (ShapeUtil::IsScalar(output_shape)) {
1585 return LiteralUtil::CreateR0<int64>(handle);
1586 } else {
1587 return CreateS64Literal(-1, output_shape);
1588 }
1589 }
1590 }
1591 }
1592
1593 StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
1594 XlaOp op, ValueInferenceMode mode) {
1595 PostorderDFSVisitor visitor(
1596 [&](int64_t handle) {
1597 return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
1598 },
1599 [&](int64_t handle) { return &(builder_->embedded_[handle]); });
1600 int64_t handle = op.handle();
1601 if (ShapeUtil::IsScalar(builder_->GetShape(op).ValueOrDie())) {
1602 TF_ASSIGN_OR_RETURN(auto result, SimplifyOp(handle));
1603 auto optimized_handle = result.Get<int64>({});
1604 if (optimized_handle != -1) {
1605 handle = optimized_handle;
1606 }
1607 }
1608 switch (mode) {
1609 case ValueInferenceMode::kLowerBound: {
1610 TF_ASSIGN_OR_RETURN(
1611 Literal value,
1612 visitor.PostOrderDFSVisit(handle,
1613 PostorderDFSNodeType::kConstantLowerBound));
1614 TF_ASSIGN_OR_RETURN(Literal mask,
1615 visitor.PostOrderDFSVisit(
1616 handle, PostorderDFSNodeType::kBoundIsDynamic));
1617 return OptionalLiteral(std::move(value), std::move(mask));
1618 }
1619 case ValueInferenceMode::kUpperBound: {
1620 TF_ASSIGN_OR_RETURN(
1621 Literal value,
1622 visitor.PostOrderDFSVisit(handle,
1623 PostorderDFSNodeType::kConstantUpperBound));
1624 TF_ASSIGN_OR_RETURN(Literal mask,
1625 visitor.PostOrderDFSVisit(
1626 handle, PostorderDFSNodeType::kBoundIsDynamic));
1627
1628 return OptionalLiteral(std::move(value), std::move(mask));
1629 }
1630 case ValueInferenceMode::kValue: {
1631 TF_ASSIGN_OR_RETURN(Literal value,
1632 visitor.PostOrderDFSVisit(
1633 handle, PostorderDFSNodeType::kConstantValue));
1634 TF_ASSIGN_OR_RETURN(Literal mask,
1635 visitor.PostOrderDFSVisit(
1636 handle, PostorderDFSNodeType::kValueIsDynamic));
1637 return OptionalLiteral(std::move(value), std::move(mask));
1638 }
1639 }
1640 }
1641
1642 } // namespace xla
1643