/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"

#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace {

// Given a list of shapes, create a shape whose dimensions are largest among
// all inputs.
//
// e.g.,
// shape_a = f32[10, 50]
// shape_b = f32[100, 10]
//
// result = f32[max(shape_a[0], shape_b[0]), max(shape_a[1], shape_b[1])]
//        = f32[100, 50]
Shape FindMaxShape(absl::Span<const Shape*> shapes) {
  CHECK(!shapes.empty());
  if (shapes[0]->IsTuple()) {
    // Recurse into sub-element.
    std::vector<Shape> results;
    results.reserve(shapes[0]->tuple_shapes_size());
    for (int i = 0; i < shapes[0]->tuple_shapes_size(); ++i) {
      std::vector<const Shape*> subshapes;
      subshapes.reserve(shapes.size());
      for (int64_t j = 0; j < shapes.size(); ++j) {
        subshapes.push_back(&shapes[j]->tuple_shapes(i));
      }
      results.push_back(FindMaxShape(absl::MakeSpan(subshapes)));
    }
    return ShapeUtil::MakeTupleShape(results);
  }
  Shape result = *shapes[0];

  for (const Shape* shape : shapes) {
    CHECK(result.rank() == shape->rank());
    for (int64_t dim = 0; dim < result.rank(); ++dim) {
      if (shape->dimensions(dim) > result.dimensions(dim)) {
        result.set_dimensions(dim, shape->dimensions(dim));
      }
    }
  }
  return result;
}

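// Pads the result of the left branch so that it has the same static dimension
// sizes as the right branch, then restores the original (logical) sizes via
// SetDimensionSize so the padded dimensions become dynamic. Tuples are handled
// element-wise.
//
// For example, if the left branch produces f32[10] and the right branch
// produces f32[40], the left result is padded with zeros to size 40 and its
// dimension 0 is marked dynamic with size 10, i.e. f32[<=40].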
StatusOr<XlaOp> ReconsileBranchDifference(const Shape& left_branch_shape,
                                          const Shape& right_branch_shape,
                                          XlaOp left_root) {
  if (left_branch_shape.IsTuple()) {
    // Invariant sanity check -- Left branch and right branch need to have
    // compatible shapes.
    CHECK(right_branch_shape.IsTuple() &&
          left_branch_shape.tuple_shapes_size() ==
              right_branch_shape.tuple_shapes_size());
    // Recurse into sub-element.
    std::vector<XlaOp> results;
    results.reserve(left_branch_shape.tuple_shapes_size());
    for (int i = 0; i < left_branch_shape.tuple_shapes_size(); ++i) {
      XlaOp sub_tuple = GetTupleElement(left_root, i);
      TF_ASSIGN_OR_RETURN(XlaOp elem,
                          ReconsileBranchDifference(
                              left_branch_shape.tuple_shapes(i),
                              right_branch_shape.tuple_shapes(i), sub_tuple));
      results.push_back(elem);
    }
    return Tuple(left_root.builder(), results);
  }
  XlaOp result = left_root;
  // Invariant sanity check -- Left branch and right branch need to have
  // compatible shapes.
  if (right_branch_shape.IsTuple()) {
    return InvalidArgument(
        "right_branch_shape should not be a tuple, received %s",
        right_branch_shape.DebugString());
  }
  if (left_branch_shape.rank() != right_branch_shape.rank()) {
    return InvalidArgument(
        "left_branch_shape.rank() != right_branch_shape.rank() (%d vs %d)",
        left_branch_shape.rank(), right_branch_shape.rank());
  }
  for (int64_t dim = 0; dim < left_branch_shape.rank(); ++dim) {
    XlaOp original_dim = GetDimensionSize(result, dim);
    if (left_branch_shape.dimensions(dim) <
        right_branch_shape.dimensions(dim)) {
      // Pad the left branch's result with zeros up to the right branch's size
      // in this dimension.
      int64_t diff = right_branch_shape.dimensions(dim) -
                     left_branch_shape.dimensions(dim);

      result = PadInDim(
          result, Zero(result.builder(), left_branch_shape.element_type()), dim,
          0, diff);
    }
    if (left_branch_shape.dimensions(dim) !=
        right_branch_shape.dimensions(dim)) {
      // Only `original_dim` elements of the padded dimension are valid, so
      // mark it dynamic with the original size.
      result = SetDimensionSize(result, original_dim, dim);
    }
  }
  return result;
}
}  // namespace
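// Like xla::Conditional, but allows the two branches to return results whose
// dimension sizes differ. If the branch root shapes are not compatible, each
// branch is rewritten so its result is padded (where smaller) to match the
// other branch's dimension sizes, with the padded dimensions marked dynamic.
//
// A minimal usage sketch (operand and computation names are hypothetical):
//
//   XlaBuilder b("example");
//   XlaOp pred = /* scalar PRED */ ...;
//   XlaOp operand = ...;
//   XlaOp result = DynamicConditional(&b, pred, operand, true_comp,
//                                     operand, false_comp);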
XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate,
                         XlaOp true_operand,
                         const XlaComputation& true_computation,
                         XlaOp false_operand,
                         const XlaComputation& false_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto true_shape = true_computation.GetProgramShape().value().result();

    auto false_shape = false_computation.GetProgramShape().value().result();

    if (ShapeUtil::Compatible(true_shape, false_shape)) {
      return xla::Conditional(predicate, true_operand, true_computation,
                              false_operand, false_computation);
    }

    // Wraps `computation` in a new computation whose result is reconciled
    // against `reference_root_shape` (padded where smaller and marked
    // dynamic).
    auto reconsile_branch =
        [](const Shape& root_shape, const Shape& operand_shape,
           const Shape& reference_root_shape,
           const XlaComputation& computation) -> StatusOr<XlaComputation> {
      xla::XlaBuilder builder("dynamic_builder");
      auto param = xla::Parameter(&builder, 0, operand_shape, "param");
      auto call = Call(&builder, computation, {param});

      auto elem =
          ReconsileBranchDifference(root_shape, reference_root_shape, call);
      if (!elem.ok()) return elem.status();
      return builder.Build();
    };
    TF_ASSIGN_OR_RETURN(
        auto true_computation_rewritten,
        reconsile_branch(true_shape,
                         builder->GetShape(true_operand).ValueOrDie(),
                         false_shape, true_computation));

    TF_ASSIGN_OR_RETURN(
        auto false_computation_rewritten,
        reconsile_branch(false_shape,
                         builder->GetShape(false_operand).ValueOrDie(),
                         true_shape, false_computation));
    return xla::Conditional(predicate, true_operand, true_computation_rewritten,
                            false_operand, false_computation_rewritten);
  });
}

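// N-ary form of DynamicConditional: `branch_index` selects one of
// `branch_computations`. If the branch root shapes are not all compatible,
// each branch result is padded up to the per-dimension maximum across all
// branches and the padded dimensions are marked dynamic, so the conditional
// has a single consistent result shape.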
XlaOp DynamicConditional(
    XlaBuilder* builder, XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<Shape> root_shapes;
    root_shapes.reserve(branch_computations.size());
    for (int64_t i = 0; i < branch_computations.size(); ++i) {
      TF_ASSIGN_OR_RETURN(auto program_shape,
                          branch_computations[i]->GetProgramShape());
      root_shapes.push_back(program_shape.result());
    }
    TF_RET_CHECK(!root_shapes.empty());
    bool all_shapes_compatible =
        absl::c_all_of(root_shapes, [&](const Shape& shape) {
          return ShapeUtil::Compatible(root_shapes[0], shape);
        });
    if (all_shapes_compatible) {
      // All shapes are compatible, fall back to static case.
      return xla::Conditional(branch_index, branch_computations,
                              branch_operands);
    }

    std::vector<const Shape*> root_shapes_ptrs;
    root_shapes_ptrs.reserve(root_shapes.size());
    for (int64_t i = 0; i < root_shapes.size(); ++i) {
      root_shapes_ptrs.push_back(&root_shapes[i]);
    }

    Shape max_shape = FindMaxShape(absl::MakeSpan(root_shapes_ptrs));

    // Rewrites a branch so its result is padded up to `reference_root_shape`
    // (here, `max_shape`) and the padded dimensions are marked dynamic.
    auto reconsile_branch =
        [](const Shape& root_shape, const Shape& operand_shape,
           const Shape& reference_root_shape,
           const XlaComputation& computation) -> StatusOr<XlaComputation> {
      xla::XlaBuilder builder("dynamic_builder");
      auto param = xla::Parameter(&builder, 0, operand_shape, "param");
      auto call = Call(&builder, computation, {param});

      auto elem =
          ReconsileBranchDifference(root_shape, reference_root_shape, call);
      if (!elem.ok()) return elem.status();
      return builder.Build();
    };
    std::vector<XlaComputation> rewritten_computations;
    rewritten_computations.reserve(branch_computations.size());

    for (int64_t i = 0; i < branch_computations.size(); ++i) {
      TF_ASSIGN_OR_RETURN(Shape branch_operand_shape,
                          builder->GetShape(branch_operands[i]));

      TF_ASSIGN_OR_RETURN(auto rewritten,
                          reconsile_branch(root_shapes[i], branch_operand_shape,
                                           max_shape, *branch_computations[i]));
      rewritten_computations.push_back(std::move(rewritten));
    }
    std::vector<const XlaComputation*> rewritten_computation_ptrs;
    rewritten_computation_ptrs.reserve(branch_computations.size());
    for (int64_t i = 0; i < branch_computations.size(); ++i) {
      rewritten_computation_ptrs.push_back(&rewritten_computations[i]);
    }
    return xla::Conditional(branch_index, rewritten_computation_ptrs,
                            branch_operands);
  });
}

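// Sets the size of `dimension` of `operand` to `dimension_size`, using value
// inference to tighten the static bound first: if the inferred upper bound of
// `dimension_size` is smaller than the operand's current dimension, the
// operand is sliced down to that bound. SetDimensionSize is only emitted when
// `dimension_size` is actually inferred to be dynamic.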
StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
                                            XlaOp operand, XlaOp dimension_size,
                                            int64_t dimension) {
  auto inferred_bound_status_or = value_inference->AnalyzeConstant(
      dimension_size, xla::ValueInferenceMode::kUpperBound);

  auto dynamism_status_or = value_inference->AnalyzeIsDynamic(dimension_size);
  TF_RETURN_IF_ERROR(inferred_bound_status_or.status());
  TF_RETURN_IF_ERROR(dynamism_status_or.status());
  if (inferred_bound_status_or->AllValid()) {
    int64_t inferred_bound = inferred_bound_status_or->Get<int32_t>({}).value();
    TF_ASSIGN_OR_RETURN(auto* shape_ptr,
                        operand.builder()->GetShapePtr(operand));
    // Found a tighter bound; slice the operand down to it.
    if (shape_ptr->dimensions(dimension) > inferred_bound) {
      operand = xla::SliceInDim(operand, 0, inferred_bound, 1, dimension);
    }
  }
  if (dynamism_status_or->Get<bool>({})) {
    // The dimension size is dynamic; make the output dynamic by calling
    // SetDimensionSize.
    operand = xla::SetDimensionSize(operand, dimension_size, dimension);
  }
  return operand;
}

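// Applies SetDimensionSize to every dimension of `operand` whose entry in
// `size_vector` (a rank-1 vector with one entry per dimension of `operand`)
// is inferred to be dynamic; statically known dimensions are left untouched.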
StatusOr<XlaOp> SetAllDimensionSizes(ValueInference* value_inference,
                                     XlaOp operand, XlaOp size_vector) {
  auto builder = value_inference->builder();
  TF_RETURN_IF_ERROR(builder->GetCurrentStatus());
  TF_ASSIGN_OR_RETURN(auto shape_ptr, builder->GetShapePtr(operand));

  for (int64_t i = 0; i < shape_ptr->rank(); ++i) {
    // If a dimension is dynamic, call set-dimension-size on the output.
    auto dim_size = xla::Slice(size_vector, {i}, {i + 1}, {1});
    dim_size = xla::Reshape(dim_size, {});
    dim_size = xla::ConvertElementType(dim_size, xla::S32);
    TF_ASSIGN_OR_RETURN(auto dynamism,
                        value_inference->AnalyzeIsDynamic(dim_size));
    if (dynamism.Get<bool>({})) {
      operand = xla::SetDimensionSize(operand, dim_size, i);
    }
  }
  return operand;
}
}  // namespace xla