• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
17 
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 namespace xla {
22 namespace {
23 
ReconsileBranchDifference(const Shape & left_branch_shape,const Shape & right_branch_shape,XlaOp left_root)24 XlaOp ReconsileBranchDifference(const Shape& left_branch_shape,
25                                 const Shape& right_branch_shape,
26                                 XlaOp left_root) {
27   if (left_branch_shape.IsTuple()) {
28     // Invariant sanity check -- Left branch and right branch need to have
29     // compatible shapes.
30     CHECK(right_branch_shape.IsTuple() &&
31           left_branch_shape.tuple_shapes_size() ==
32               right_branch_shape.tuple_shapes_size());
33     // Recurse into sub-element.
34     std::vector<XlaOp> results;
35     results.reserve(left_branch_shape.tuple_shapes_size());
36     for (int64 i = 0; i < left_branch_shape.tuple_shapes_size(); ++i) {
37       XlaOp sub_tuple = GetTupleElement(left_root, i);
38       XlaOp elem = ReconsileBranchDifference(left_branch_shape.tuple_shapes(i),
39                                              right_branch_shape.tuple_shapes(i),
40                                              sub_tuple);
41       results.push_back(elem);
42     }
43     return Tuple(left_root.builder(), results);
44   }
45   XlaOp result = left_root;
46   // Invariant sanity check -- Left branch and right branch need to have
47   // compatible shapes.
48   CHECK(!right_branch_shape.IsTuple());
49   CHECK(left_branch_shape.rank() == right_branch_shape.rank());
50   for (int64 dim = 0; dim < left_branch_shape.rank(); ++dim) {
51     XlaOp original_dim = GetDimensionSize(result, dim);
52     if (left_branch_shape.dimensions(dim) <
53         right_branch_shape.dimensions(dim)) {
54       int64 diff = right_branch_shape.dimensions(dim) -
55                    left_branch_shape.dimensions(dim);
56 
57       result = PadInDim(
58           result, Zero(result.builder(), left_branch_shape.element_type()), dim,
59           0, diff);
60     }
61     if (left_branch_shape.dimensions(dim) !=
62         right_branch_shape.dimensions(dim)) {
63       result = SetDimensionSize(result, original_dim, dim);
64     }
65   }
66   return result;
67 }
68 }  // namespace
DynamicConditional(XlaBuilder * builder,XlaOp predicate,XlaOp true_operand,const XlaComputation & true_computation,XlaOp false_operand,const XlaComputation & false_computation)69 XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate,
70                          XlaOp true_operand,
71                          const XlaComputation& true_computation,
72                          XlaOp false_operand,
73                          const XlaComputation& false_computation) {
74   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
75     auto true_shape =
76         true_computation.GetProgramShape().ConsumeValueOrDie().result();
77 
78     auto false_shape =
79         false_computation.GetProgramShape().ConsumeValueOrDie().result();
80 
81     if (ShapeUtil::Compatible(true_shape, false_shape)) {
82       return xla::Conditional(predicate, true_operand, true_computation,
83                               false_operand, false_computation);
84     }
85 
86     auto reconsile_branch = [](const Shape& root_shape,
87                                const Shape& operand_shape,
88                                const Shape& reference_root_shape,
89                                const XlaComputation& computation) {
90       xla::XlaBuilder builder("dynamic_builder");
91       auto param = xla::Parameter(&builder, 0, operand_shape, "param");
92       auto call = Call(&builder, computation, {param});
93 
94       ReconsileBranchDifference(root_shape, reference_root_shape, call);
95       return builder.Build();
96     };
97     TF_ASSIGN_OR_RETURN(
98         auto true_computation_rewritten,
99         reconsile_branch(true_shape,
100                          builder->GetShape(true_operand).ValueOrDie(),
101                          false_shape, true_computation));
102 
103     TF_ASSIGN_OR_RETURN(
104         auto false_computation_rewritten,
105         reconsile_branch(false_shape,
106                          builder->GetShape(false_operand).ValueOrDie(),
107                          true_shape, false_computation));
108     return xla::Conditional(predicate, true_operand, true_computation_rewritten,
109                             false_operand, false_computation_rewritten);
110   });
111 }
112 
SetDimensionSizeWithRebound(ValueInference * value_inference,XlaOp operand,XlaOp dimension_size,int64_t dimension)113 StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
114                                             XlaOp operand, XlaOp dimension_size,
115                                             int64_t dimension) {
116   auto inferred_bound_status_or = value_inference->AnalyzeConstant(
117       dimension_size, xla::ValueInferenceMode::kUpperBound);
118   TF_RETURN_IF_ERROR(inferred_bound_status_or.status());
119   if (inferred_bound_status_or->AllValid()) {
120     int64 inferred_bound = inferred_bound_status_or->Get<int32>({}).value();
121     TF_ASSIGN_OR_RETURN(auto* shape_ptr,
122                         operand.builder()->GetShapePtr(operand));
123     // Found a tighter bound, do a slice.
124     if (shape_ptr->dimensions(dimension) > inferred_bound)
125       operand = xla::SliceInDim(operand, 0, inferred_bound, 1, dimension);
126   }
127   operand = xla::SetDimensionSize(operand, dimension_size, dimension);
128   return operand;
129 }
130 }  // namespace xla
131