• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
18 
19 #include <functional>
20 #include <vector>
21 
22 #include "absl/strings/string_view.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 
28 namespace xla {
29 
30 // Function that builds a loop condition. Takes as input a sequence of input
31 // values, and returns a boolean value representing if the condition succeeds.
32 typedef std::function<StatusOr<XlaOp>(absl::Span<const XlaOp>, XlaBuilder*)>
33     WhileLoopHelperConditionFunction;
34 
35 // Function that builds a loop body. Takes as input a sequence of input values
36 // and returns a sequence of output values.
37 typedef std::function<StatusOr<std::vector<XlaOp>>(absl::Span<const XlaOp>,
38                                                    XlaBuilder*)>
39     WhileLoopHelperBodyFunction;
40 
41 // Helper function for building an XLA while loop, where the values carried by
42 // the loop are a tuple of values, e.g., (a, b, c):
43 // while(
44 //   condition: (a, b, c) -> bool,
45 //   body: (a, b, c) -> (a, b, c)
46 //   init: (a, b, c)
47 // )
48 // 'name' is a descriptive name for the loop.
49 StatusOr<std::vector<XlaOp>> WhileLoopHelper(
50     const WhileLoopHelperConditionFunction& condition_function,
51     const WhileLoopHelperBodyFunction& body_function,
52     absl::Span<const XlaOp> initial_values, absl::string_view name,
53     XlaBuilder* builder);
54 
55 // Builds an XLA loop that repeats a computation `num_iterations` times.
56 //
57 // The body function (ForEachIndexBodyFunction) takes as input a pair of
58 // (current iteration number, loop-carried values), and returns an updated
59 // vector of the loop-carried values.
60 typedef std::function<StatusOr<std::vector<XlaOp>>(
61     XlaOp, absl::Span<const XlaOp>, XlaBuilder*)>
62     ForEachIndexBodyFunction;
63 
64 StatusOr<std::vector<XlaOp>> ForEachIndex(
65     int64 num_iterations, PrimitiveType num_iterations_type,
66     const ForEachIndexBodyFunction& body_function,
67     absl::Span<const XlaOp> initial_values, absl::string_view name,
68     XlaBuilder* builder);
69 
70 }  // namespace xla
71 
72 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
73