1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
17
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 namespace xla {
22 namespace {
23
ReconsileBranchDifference(const Shape & left_branch_shape,const Shape & right_branch_shape,XlaOp left_root)24 XlaOp ReconsileBranchDifference(const Shape& left_branch_shape,
25 const Shape& right_branch_shape,
26 XlaOp left_root) {
27 if (left_branch_shape.IsTuple()) {
28 // Invariant sanity check -- Left branch and right branch need to have
29 // compatible shapes.
30 CHECK(right_branch_shape.IsTuple() &&
31 left_branch_shape.tuple_shapes_size() ==
32 right_branch_shape.tuple_shapes_size());
33 // Recurse into sub-element.
34 std::vector<XlaOp> results;
35 results.reserve(left_branch_shape.tuple_shapes_size());
36 for (int64 i = 0; i < left_branch_shape.tuple_shapes_size(); ++i) {
37 XlaOp sub_tuple = GetTupleElement(left_root, i);
38 XlaOp elem = ReconsileBranchDifference(left_branch_shape.tuple_shapes(i),
39 right_branch_shape.tuple_shapes(i),
40 sub_tuple);
41 results.push_back(elem);
42 }
43 return Tuple(left_root.builder(), results);
44 }
45 XlaOp result = left_root;
46 // Invariant sanity check -- Left branch and right branch need to have
47 // compatible shapes.
48 CHECK(!right_branch_shape.IsTuple());
49 CHECK(left_branch_shape.rank() == right_branch_shape.rank());
50 for (int64 dim = 0; dim < left_branch_shape.rank(); ++dim) {
51 XlaOp original_dim = GetDimensionSize(result, dim);
52 if (left_branch_shape.dimensions(dim) <
53 right_branch_shape.dimensions(dim)) {
54 int64 diff = right_branch_shape.dimensions(dim) -
55 left_branch_shape.dimensions(dim);
56
57 result = PadInDim(
58 result, Zero(result.builder(), left_branch_shape.element_type()), dim,
59 0, diff);
60 }
61 if (left_branch_shape.dimensions(dim) !=
62 right_branch_shape.dimensions(dim)) {
63 result = SetDimensionSize(result, original_dim, dim);
64 }
65 }
66 return result;
67 }
68 } // namespace
DynamicConditional(XlaBuilder * builder,XlaOp predicate,XlaOp true_operand,const XlaComputation & true_computation,XlaOp false_operand,const XlaComputation & false_computation)69 XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate,
70 XlaOp true_operand,
71 const XlaComputation& true_computation,
72 XlaOp false_operand,
73 const XlaComputation& false_computation) {
74 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
75 auto true_shape =
76 true_computation.GetProgramShape().ConsumeValueOrDie().result();
77
78 auto false_shape =
79 false_computation.GetProgramShape().ConsumeValueOrDie().result();
80
81 if (ShapeUtil::Compatible(true_shape, false_shape)) {
82 return xla::Conditional(predicate, true_operand, true_computation,
83 false_operand, false_computation);
84 }
85
86 auto reconsile_branch = [](const Shape& root_shape,
87 const Shape& operand_shape,
88 const Shape& reference_root_shape,
89 const XlaComputation& computation) {
90 xla::XlaBuilder builder("dynamic_builder");
91 auto param = xla::Parameter(&builder, 0, operand_shape, "param");
92 auto call = Call(&builder, computation, {param});
93
94 ReconsileBranchDifference(root_shape, reference_root_shape, call);
95 return builder.Build();
96 };
97 TF_ASSIGN_OR_RETURN(
98 auto true_computation_rewritten,
99 reconsile_branch(true_shape,
100 builder->GetShape(true_operand).ValueOrDie(),
101 false_shape, true_computation));
102
103 TF_ASSIGN_OR_RETURN(
104 auto false_computation_rewritten,
105 reconsile_branch(false_shape,
106 builder->GetShape(false_operand).ValueOrDie(),
107 true_shape, false_computation));
108 return xla::Conditional(predicate, true_operand, true_computation_rewritten,
109 false_operand, false_computation_rewritten);
110 });
111 }
112
SetDimensionSizeWithRebound(ValueInference * value_inference,XlaOp operand,XlaOp dimension_size,int64_t dimension)113 StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
114 XlaOp operand, XlaOp dimension_size,
115 int64_t dimension) {
116 auto inferred_bound_status_or = value_inference->AnalyzeConstant(
117 dimension_size, xla::ValueInferenceMode::kUpperBound);
118 TF_RETURN_IF_ERROR(inferred_bound_status_or.status());
119 if (inferred_bound_status_or->AllValid()) {
120 int64 inferred_bound = inferred_bound_status_or->Get<int32>({}).value();
121 TF_ASSIGN_OR_RETURN(auto* shape_ptr,
122 operand.builder()->GetShapePtr(operand));
123 // Found a tighter bound, do a slice.
124 if (shape_ptr->dimensions(dimension) > inferred_bound)
125 operand = xla::SliceInDim(operand, 0, inferred_bound, 1, dimension);
126 }
127 operand = xla::SetDimensionSize(operand, dimension_size, dimension);
128 return operand;
129 }
130 } // namespace xla
131