1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/loops.h"
17
18 #include "tensorflow/compiler/xla/client/lib/constants.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22
23 namespace xla {
24
WhileLoopHelper(const WhileLoopHelperConditionFunction & condition_function,const WhileLoopHelperBodyFunction & body_function,absl::Span<const XlaOp> initial_values,absl::string_view name,XlaBuilder * builder)25 StatusOr<std::vector<XlaOp>> WhileLoopHelper(
26 const WhileLoopHelperConditionFunction& condition_function,
27 const WhileLoopHelperBodyFunction& body_function,
28 absl::Span<const XlaOp> initial_values, absl::string_view name,
29 XlaBuilder* builder) {
30 int arity = initial_values.size();
31 std::vector<Shape> var_shapes;
32 var_shapes.reserve(arity);
33 for (const XlaOp& input : initial_values) {
34 TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
35 var_shapes.push_back(std::move(shape));
36 }
37 Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes);
38
39 // Unpacks a tuple into its component parts.
40 auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) {
41 std::vector<XlaOp> elements(arity);
42 for (int i = 0; i < arity; ++i) {
43 elements[i] = GetTupleElement(tuple, i);
44 }
45 return elements;
46 };
47
48 // Build the condition.
49 std::unique_ptr<XlaBuilder> cond_builder =
50 builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
51 {
52 auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
53
54 TF_RETURN_IF_ERROR(
55 condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
56 cond_builder.get())
57 .status());
58 }
59 TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
60
61 // Build the body.
62 std::unique_ptr<XlaBuilder> body_builder =
63 builder->CreateSubBuilder(absl::StrCat(name, "_body"));
64 {
65 auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter");
66
67 TF_ASSIGN_OR_RETURN(
68 auto result,
69 body_function(unpack_tuple(parameter, arity, body_builder.get()),
70 body_builder.get()));
71
72 TF_RET_CHECK(result.size() == initial_values.size());
73 Tuple(body_builder.get(), result);
74 }
75 TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
76
77 auto outputs = While(cond, body, Tuple(builder, initial_values));
78
79 return unpack_tuple(outputs, arity, builder);
80 }
81
ForEachIndex(int64 num_iterations,PrimitiveType num_iterations_type,const ForEachIndexBodyFunction & body_function,absl::Span<const XlaOp> initial_values,absl::string_view name,XlaBuilder * builder)82 StatusOr<std::vector<XlaOp>> ForEachIndex(
83 int64 num_iterations, PrimitiveType num_iterations_type,
84 const ForEachIndexBodyFunction& body_function,
85 absl::Span<const XlaOp> initial_values, absl::string_view name,
86 XlaBuilder* builder) {
87 auto while_cond_fn = [&](absl::Span<const XlaOp> values,
88 XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
89 return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type,
90 num_iterations));
91 };
92 auto while_body_fn =
93 [&](absl::Span<const XlaOp> values,
94 XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
95 XlaOp iteration = values[0];
96
97 std::vector<XlaOp> updated_values;
98 updated_values.reserve(values.size());
99 updated_values.push_back(Add(
100 iteration,
101 ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type))));
102
103 values.remove_prefix(1);
104 TF_ASSIGN_OR_RETURN(std::vector<XlaOp> body_outputs,
105 body_function(iteration, values, body_builder));
106 updated_values.insert(updated_values.end(), body_outputs.begin(),
107 body_outputs.end());
108 return updated_values;
109 };
110
111 std::vector<XlaOp> values;
112 values.reserve(initial_values.size() + 1);
113 values.push_back(
114 ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type)));
115 values.insert(values.end(), initial_values.begin(), initial_values.end());
116
117 TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
118 values, name, builder));
119 values.erase(values.begin(), values.begin() + 1);
120 return values;
121 }
122
123 } // namespace xla
124