1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
17
18 #include <utility>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/util.h"
26
27 namespace xla {
28 namespace {
29
30 // Given a list of shapes, create a shape whose dimensions are largest among all
31 // inputs.
32 //
33 // e.g.,
34 // shape_a = f32[10, 50]
35 // shape_b = f32[100, 10]
36 //
37 // result = f32[max(shape_a[0], shape_b[0]), max(shape_a[1], shape_b[1])]
38 // = f32[100, 50]
FindMaxShape(absl::Span<const Shape * > shapes)39 Shape FindMaxShape(absl::Span<const Shape*> shapes) {
40 CHECK(!shapes.empty());
41 if (shapes[0]->IsTuple()) {
42 // Recurse into sub-element.
43 std::vector<Shape> results;
44 results.reserve(shapes[0]->tuple_shapes_size());
45 for (int i = 0; i < shapes[0]->tuple_shapes_size(); ++i) {
46 std::vector<const Shape*> subshapes;
47 subshapes.reserve(shapes.size());
48 for (int64_t j = 0; j < shapes.size(); ++j) {
49 subshapes.push_back(&shapes[j]->tuple_shapes(i));
50 }
51 results.push_back(FindMaxShape(absl::MakeSpan(subshapes)));
52 }
53 return ShapeUtil::MakeTupleShape(results);
54 }
55 Shape result = *shapes[0];
56
57 for (const Shape* shape : shapes) {
58 CHECK(result.rank() == shape->rank());
59 for (int64_t dim = 0; dim < result.rank(); ++dim) {
60 if (shape->dimensions(dim) > result.dimensions(dim)) {
61 result.set_dimensions(dim, shape->dimensions(dim));
62 }
63 }
64 }
65 return result;
66 }
67
ReconsileBranchDifference(const Shape & left_branch_shape,const Shape & right_branch_shape,XlaOp left_root)68 StatusOr<XlaOp> ReconsileBranchDifference(const Shape& left_branch_shape,
69 const Shape& right_branch_shape,
70 XlaOp left_root) {
71 if (left_branch_shape.IsTuple()) {
72 // Invariant sanity check -- Left branch and right branch need to have
73 // compatible shapes.
74 CHECK(right_branch_shape.IsTuple() &&
75 left_branch_shape.tuple_shapes_size() ==
76 right_branch_shape.tuple_shapes_size());
77 // Recurse into sub-element.
78 std::vector<XlaOp> results;
79 results.reserve(left_branch_shape.tuple_shapes_size());
80 for (int i = 0; i < left_branch_shape.tuple_shapes_size(); ++i) {
81 XlaOp sub_tuple = GetTupleElement(left_root, i);
82 TF_ASSIGN_OR_RETURN(XlaOp elem,
83 ReconsileBranchDifference(
84 left_branch_shape.tuple_shapes(i),
85 right_branch_shape.tuple_shapes(i), sub_tuple));
86 results.push_back(elem);
87 }
88 return Tuple(left_root.builder(), results);
89 }
90 XlaOp result = left_root;
91 // Invariant sanity check -- Left branch and right branch need to have
92 // compatible shapes.
93 if (right_branch_shape.IsTuple()) {
94 return InvalidArgument(
95 "right_branch_shape should not be a tuple, received %s",
96 right_branch_shape.DebugString());
97 }
98 if (left_branch_shape.rank() != right_branch_shape.rank()) {
99 return InvalidArgument(
100 "left_branch_shape.rank() != right_branch_shape.rank() (%d vs %d)",
101 left_branch_shape.rank(), right_branch_shape.rank());
102 }
103 for (int64_t dim = 0; dim < left_branch_shape.rank(); ++dim) {
104 XlaOp original_dim = GetDimensionSize(result, dim);
105 if (left_branch_shape.dimensions(dim) <
106 right_branch_shape.dimensions(dim)) {
107 int64_t diff = right_branch_shape.dimensions(dim) -
108 left_branch_shape.dimensions(dim);
109
110 result = PadInDim(
111 result, Zero(result.builder(), left_branch_shape.element_type()), dim,
112 0, diff);
113 }
114 if (left_branch_shape.dimensions(dim) !=
115 right_branch_shape.dimensions(dim)) {
116 result = SetDimensionSize(result, original_dim, dim);
117 }
118 }
119 return result;
120 }
121 } // namespace
DynamicConditional(XlaBuilder * builder,XlaOp predicate,XlaOp true_operand,const XlaComputation & true_computation,XlaOp false_operand,const XlaComputation & false_computation)122 XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate,
123 XlaOp true_operand,
124 const XlaComputation& true_computation,
125 XlaOp false_operand,
126 const XlaComputation& false_computation) {
127 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
128 auto true_shape = true_computation.GetProgramShape().value().result();
129
130 auto false_shape = false_computation.GetProgramShape().value().result();
131
132 if (ShapeUtil::Compatible(true_shape, false_shape)) {
133 return xla::Conditional(predicate, true_operand, true_computation,
134 false_operand, false_computation);
135 }
136
137 auto reconsile_branch =
138 [](const Shape& root_shape, const Shape& operand_shape,
139 const Shape& reference_root_shape,
140 const XlaComputation& computation) -> StatusOr<XlaComputation> {
141 xla::XlaBuilder builder("dynamic_builder");
142 auto param = xla::Parameter(&builder, 0, operand_shape, "param");
143 auto call = Call(&builder, computation, {param});
144
145 auto elem =
146 ReconsileBranchDifference(root_shape, reference_root_shape, call);
147 if (!elem.ok()) return elem.status();
148 return builder.Build();
149 };
150 TF_ASSIGN_OR_RETURN(
151 auto true_computation_rewritten,
152 reconsile_branch(true_shape,
153 builder->GetShape(true_operand).ValueOrDie(),
154 false_shape, true_computation));
155
156 TF_ASSIGN_OR_RETURN(
157 auto false_computation_rewritten,
158 reconsile_branch(false_shape,
159 builder->GetShape(false_operand).ValueOrDie(),
160 true_shape, false_computation));
161 return xla::Conditional(predicate, true_operand, true_computation_rewritten,
162 false_operand, false_computation_rewritten);
163 });
164 }
165
DynamicConditional(XlaBuilder * builder,XlaOp branch_index,absl::Span<const XlaComputation * const> branch_computations,absl::Span<const XlaOp> branch_operands)166 XlaOp DynamicConditional(
167 XlaBuilder* builder, XlaOp branch_index,
168 absl::Span<const XlaComputation* const> branch_computations,
169 absl::Span<const XlaOp> branch_operands) {
170 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
171 std::vector<Shape> root_shapes;
172 root_shapes.reserve(branch_computations.size());
173 for (int64_t i = 0; i < branch_computations.size(); ++i) {
174 TF_ASSIGN_OR_RETURN(auto program_shape,
175 branch_computations[i]->GetProgramShape());
176 root_shapes.push_back(program_shape.result());
177 }
178 TF_RET_CHECK(!root_shapes.empty());
179 bool all_shapes_compatible =
180 absl::c_all_of(root_shapes, [&](const Shape& shape) {
181 return ShapeUtil::Compatible(root_shapes[0], shape);
182 });
183 if (all_shapes_compatible) {
184 // All shapes are compatible, fall back to static case.
185 return xla::Conditional(branch_index, branch_computations,
186 branch_operands);
187 }
188
189 std::vector<const Shape*> root_shapes_ptrs;
190 root_shapes_ptrs.reserve(root_shapes.size());
191 for (int64_t i = 0; i < root_shapes.size(); ++i) {
192 root_shapes_ptrs.push_back(&root_shapes[i]);
193 }
194
195 Shape max_shape = FindMaxShape(absl::MakeSpan(root_shapes_ptrs));
196
197 auto reconsile_branch =
198 [](const Shape& root_shape, const Shape& operand_shape,
199 const Shape& reference_root_shape,
200 const XlaComputation& computation) -> StatusOr<XlaComputation> {
201 xla::XlaBuilder builder("dynamic_builder");
202 auto param = xla::Parameter(&builder, 0, operand_shape, "param");
203 auto call = Call(&builder, computation, {param});
204
205 auto elem =
206 ReconsileBranchDifference(root_shape, reference_root_shape, call);
207 if (!elem.ok()) return elem.status();
208 return builder.Build();
209 };
210 std::vector<XlaComputation> rewritten_computations;
211 rewritten_computations.reserve(branch_computations.size());
212
213 for (int64_t i = 0; i < branch_computations.size(); ++i) {
214 TF_ASSIGN_OR_RETURN(Shape branch_operand_shape,
215 builder->GetShape(branch_operands[i]));
216
217 TF_ASSIGN_OR_RETURN(auto rewritten,
218 reconsile_branch(root_shapes[i], branch_operand_shape,
219 max_shape, *branch_computations[i]));
220 rewritten_computations.push_back(std::move(rewritten));
221 }
222 std::vector<const XlaComputation*> rewritten_computation_ptrs;
223 rewritten_computation_ptrs.reserve(branch_computations.size());
224 for (int64_t i = 0; i < branch_computations.size(); ++i) {
225 rewritten_computation_ptrs.push_back(&rewritten_computations[i]);
226 }
227 return xla::Conditional(branch_index, rewritten_computation_ptrs,
228 branch_operands);
229 });
230 }
231
SetDimensionSizeWithRebound(ValueInference * value_inference,XlaOp operand,XlaOp dimension_size,int64_t dimension)232 StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
233 XlaOp operand, XlaOp dimension_size,
234 int64_t dimension) {
235 auto inferred_bound_status_or = value_inference->AnalyzeConstant(
236 dimension_size, xla::ValueInferenceMode::kUpperBound);
237
238 auto dynamism_status_or = value_inference->AnalyzeIsDynamic(dimension_size);
239 TF_RETURN_IF_ERROR(inferred_bound_status_or.status());
240 TF_RETURN_IF_ERROR(dynamism_status_or.status());
241 if (inferred_bound_status_or->AllValid()) {
242 int64_t inferred_bound = inferred_bound_status_or->Get<int32_t>({}).value();
243 TF_ASSIGN_OR_RETURN(auto* shape_ptr,
244 operand.builder()->GetShapePtr(operand));
245 // Found a tighter bound, do a slice.
246 if (shape_ptr->dimensions(dimension) > inferred_bound) {
247 operand = xla::SliceInDim(operand, 0, inferred_bound, 1, dimension);
248 }
249 }
250 if (dynamism_status_or->Get<bool>({})) {
251 // dimension size is dynamic, make output dynamic by calling set dimension
252 // size.
253 operand = xla::SetDimensionSize(operand, dimension_size, dimension);
254 }
255 return operand;
256 }
257
SetAllDimensionSizes(ValueInference * value_inference,XlaOp operand,XlaOp size_vector)258 StatusOr<XlaOp> SetAllDimensionSizes(ValueInference* value_inference,
259 XlaOp operand, XlaOp size_vector) {
260 auto builder = value_inference->builder();
261 TF_RETURN_IF_ERROR(builder->GetCurrentStatus());
262 TF_ASSIGN_OR_RETURN(auto shape_ptr, builder->GetShapePtr(operand));
263
264 for (int64_t i = 0; i < shape_ptr->rank(); ++i) {
265 // If a dimension is dynamic, call set-dimension-size on the output.
266 auto dim_size = xla::Slice(size_vector, {i}, {i + 1}, {1});
267 dim_size = xla::Reshape(dim_size, {});
268 dim_size = xla::ConvertElementType(dim_size, xla::S32);
269 TF_ASSIGN_OR_RETURN(auto dynamism,
270 value_inference->AnalyzeIsDynamic(dim_size));
271 if (dynamism.Get<bool>({})) {
272 operand = xla::SetDimensionSize(operand, dim_size, i);
273 }
274 }
275 return operand;
276 }
277 } // namespace xla
278