1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/loops.h"
17 
18 #include "tensorflow/compiler/xla/client/lib/constants.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22 
23 namespace xla {
24 
WhileLoopHelper(const WhileLoopHelperConditionFunction & condition_function,const WhileLoopHelperBodyFunction & body_function,absl::Span<const XlaOp> initial_values,absl::string_view name,XlaBuilder * builder)25 StatusOr<std::vector<XlaOp>> WhileLoopHelper(
26     const WhileLoopHelperConditionFunction& condition_function,
27     const WhileLoopHelperBodyFunction& body_function,
28     absl::Span<const XlaOp> initial_values, absl::string_view name,
29     XlaBuilder* builder) {
30   int arity = initial_values.size();
31   std::vector<Shape> var_shapes;
32   var_shapes.reserve(arity);
33   for (const XlaOp& input : initial_values) {
34     TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
35     var_shapes.push_back(std::move(shape));
36   }
37   Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes);
38 
39   // Unpacks a tuple into its component parts.
40   auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) {
41     std::vector<XlaOp> elements(arity);
42     for (int i = 0; i < arity; ++i) {
43       elements[i] = GetTupleElement(tuple, i);
44     }
45     return elements;
46   };
47 
48   // Build the condition.
49   std::unique_ptr<XlaBuilder> cond_builder =
50       builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
51   {
52     auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
53 
54     TF_RETURN_IF_ERROR(
55         condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
56                            cond_builder.get())
57             .status());
58   }
59   TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
60 
61   // Build the body.
62   std::unique_ptr<XlaBuilder> body_builder =
63       builder->CreateSubBuilder(absl::StrCat(name, "_body"));
64   {
65     auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter");
66 
67     TF_ASSIGN_OR_RETURN(
68         auto result,
69         body_function(unpack_tuple(parameter, arity, body_builder.get()),
70                       body_builder.get()));
71 
72     TF_RET_CHECK(result.size() == initial_values.size());
73     Tuple(body_builder.get(), result);
74   }
75   TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
76 
77   auto outputs = While(cond, body, Tuple(builder, initial_values));
78 
79   return unpack_tuple(outputs, arity, builder);
80 }
81 
ForEachIndex(int64_t num_iterations,PrimitiveType num_iterations_type,const ForEachIndexBodyFunction & body_function,absl::Span<const XlaOp> initial_values,absl::string_view name,XlaBuilder * builder)82 StatusOr<std::vector<XlaOp>> ForEachIndex(
83     int64_t num_iterations, PrimitiveType num_iterations_type,
84     const ForEachIndexBodyFunction& body_function,
85     absl::Span<const XlaOp> initial_values, absl::string_view name,
86     XlaBuilder* builder) {
87   auto while_cond_fn = [&](absl::Span<const XlaOp> values,
88                            XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
89     return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type,
90                                             num_iterations));
91   };
92   auto while_body_fn =
93       [&](absl::Span<const XlaOp> values,
94           XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
95     XlaOp iteration = values[0];
96 
97     std::vector<XlaOp> updated_values;
98     updated_values.reserve(values.size());
99     updated_values.push_back(Add(
100         iteration,
101         ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type))));
102 
103     values.remove_prefix(1);
104     TF_ASSIGN_OR_RETURN(std::vector<XlaOp> body_outputs,
105                         body_function(iteration, values, body_builder));
106     updated_values.insert(updated_values.end(), body_outputs.begin(),
107                           body_outputs.end());
108     return updated_values;
109   };
110 
111   std::vector<XlaOp> values;
112   values.reserve(initial_values.size() + 1);
113   values.push_back(
114       ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type)));
115   values.insert(values.end(), initial_values.begin(), initial_values.end());
116 
117   TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
118                                               values, name, builder));
119   values.erase(values.begin(), values.begin() + 1);
120   return values;
121 }
122 
123 }  // namespace xla
124