• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
18 
19 #include <string>
20 
21 #include "absl/strings/string_view.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Value.h"
25 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 
28 namespace xla {
29 // A thin wrapper around llvm_loop.h to make code generating structured control
30 // flow more readable.
31 class KernelSupportLibrary {
32  public:
33   // `b` is the llvm::IRBuilder instance used to generate LLVM IR.
34   // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop
35   // generated by this instance of KernelSupportLibrary.
36   explicit KernelSupportLibrary(
37       llvm::IRBuilder<>* b,
38       llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll,
39       bool prevent_vectorization = true)
b_(b)40       : b_(b),
41         unroll_mode_(unroll_mode),
42         prevent_vectorization_(prevent_vectorization) {}
43 
44   // Generates the following control flow structure:
45   //
46   //   if (`start` < `end`) {
47   //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
48   //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
49   //       `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`;
50   //   }
51   Status ForWithStatus(
52       absl::string_view name, llvm::Value* start, llvm::Value* end,
53       llvm::Value* step,
54       const std::function<Status(llvm::Value* ind_var,
55                                  bool is_first_iteration)>& for_body_generator);
56 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)57   void For(
58       absl::string_view name, llvm::Value* start, llvm::Value* end,
59       llvm::Value* step,
60       const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
61           for_body_generator) {
62     CHECK_EQ(Status::OK(),
63              ForWithStatus(
64                  name, start, end, step,
65                  [&](llvm::Value* ind_var, bool is_first_iteration) -> Status {
66                    for_body_generator(ind_var, is_first_iteration);
67                    return Status::OK();
68                  }));
69   }
70 
ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)71   Status ForWithStatus(
72       absl::string_view name, int64 start, int64 end, int64 step,
73       const std::function<Status(
74           llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) {
75     return ForWithStatus(name, /*start=*/b_->getInt64(start),
76                          /*end=*/b_->getInt64(end),
77                          /*step=*/b_->getInt64(step), for_body_generator);
78   }
79 
For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)80   void For(
81       absl::string_view name, int64 start, int64 end, int64 step,
82       const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
83           for_body_generator) {
84     For(name, /*start=*/b_->getInt64(start),
85         /*end=*/b_->getInt64(end),
86         /*step=*/b_->getInt64(step), for_body_generator);
87   }
88 
89   // Generates the following control flow structure if `peel_first_iteration` is
90   // true:
91   //
92   //   if (`start` < `end`) {
93   //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`;
94   //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
95   //       `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`;
96   //   }
97   //
98   // and the following if `peel_first_iteration` is false:
99   //
100   //   for (i64 i = `start`; i s< `end`; i += `step`)
101   //     `for_body_generator(/*ind_var=*/,i,
102   //                         /*is_first_iteration=*/,(i != `start`))`;
103   Status ForWithStatus(
104       absl::string_view name, llvm::Value* start, llvm::Value* end,
105       llvm::Value* step, bool peel_first_iteration,
106       const std::function<Status(llvm::Value* ind_var,
107                                  llvm::Value* is_first_iteration)>&
108           for_body_generator);
109 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)110   void For(absl::string_view name, llvm::Value* start, llvm::Value* end,
111            llvm::Value* step, bool peel_first_iteration,
112            const std::function<void(llvm::Value* ind_var,
113                                     llvm::Value* is_first_iteration)>&
114                for_body_generator) {
115     TF_CHECK_OK(ForWithStatus(
116         name, start, end, step, peel_first_iteration,
117         [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status {
118           for_body_generator(ind_var, is_first_iteration);
119           return Status::OK();
120         }));
121   }
122 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<Status (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)123   Status ForWithStatus(
124       absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
125       bool peel_first_iteration,
126       const std::function<Status(llvm::Value* ind_var,
127                                  llvm::Value* is_first_iteration)>&
128           for_body_generator) {
129     return ForWithStatus(
130         name, /*start=*/start, /*end=*/end,
131         /*step=*/llvm::ConstantInt::get(start->getType(), step),
132         peel_first_iteration, for_body_generator);
133   }
134 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)135   void For(absl::string_view name, llvm::Value* start, llvm::Value* end,
136            int64 step, bool peel_first_iteration,
137            const std::function<void(llvm::Value* ind_var,
138                                     llvm::Value* is_first_iteration)>&
139                for_body_generator) {
140     For(name, /*start=*/start, /*end=*/end,
141         /*step=*/llvm::ConstantInt::get(start->getType(), step),
142         peel_first_iteration, for_body_generator);
143   }
144 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)145   Status ForWithStatus(
146       absl::string_view name, llvm::Value* start, llvm::Value* end,
147       llvm::Value* step,
148       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
149     return ForWithStatus(name, start, end, step,
150                          /*peel_first_iteration=*/false,
151                          [&](llvm::Value* indvar, llvm::Value*) -> Status {
152                            return for_body_generator(indvar);
153                          });
154   }
155 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)156   void For(
157       absl::string_view name, llvm::Value* start, llvm::Value* end,
158       llvm::Value* step,
159       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
160     For(name, start, end, step,
161         /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) {
162           return for_body_generator(indvar);
163         });
164   }
165 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)166   Status ForWithStatus(
167       absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
168       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
169     return ForWithStatus(name, start, end,
170                          llvm::ConstantInt::get(start->getType(), step),
171                          /*peel_first_iteration=*/false,
172                          [&](llvm::Value* indvar, llvm::Value*) -> Status {
173                            return for_body_generator(indvar);
174                          });
175   }
176 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)177   void For(
178       absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
179       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
180     For(name, start, end, llvm::ConstantInt::get(start->getType(), step),
181         for_body_generator);
182   }
183 
ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)184   Status ForWithStatus(
185       absl::string_view name, int64 start, int64 end, int64 step,
186       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
187     return ForWithStatus(name, /*start=*/b_->getInt64(start),
188                          /*end=*/b_->getInt64(end),
189                          /*step=*/b_->getInt64(step), for_body_generator);
190   }
191 
For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)192   void For(
193       absl::string_view name, int64 start, int64 end, int64 step,
194       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
195     For(name, /*start=*/b_->getInt64(start),
196         /*end=*/b_->getInt64(end),
197         /*step=*/b_->getInt64(step), for_body_generator);
198   }
199 
200   // Generates the following control flow structure:
201   //
202   //   if (`condition`)
203   //     `true_block_generator()`;
204   //   else
205   //      `false_block_generator()`;
206   Status IfWithStatus(
207       absl::string_view name, llvm::Value* condition,
208       const std::function<Status()>& true_block_generator,
209       const std::function<Status()>& false_block_generator = []() -> Status {
210         return Status::OK();
211       });
212 
213   Status IfWithStatus(
214       llvm::Value* condition,
215       const std::function<Status()>& true_block_generator,
216       const std::function<Status()>& false_block_generator = []() -> Status {
217         return Status::OK();
218       }) {
219     return IfWithStatus("", condition, true_block_generator,
220                         false_block_generator);
221   }
222 
223   void If(
224       llvm::Value* condition, const std::function<void()>& true_block_generator,
225       const std::function<void()>& false_block_generator = []() {}) {
226     If("", condition, true_block_generator, false_block_generator);
227   }
228 
229   void If(
230       absl::string_view name, llvm::Value* condition,
231       const std::function<void()>& true_block_generator,
232       const std::function<void()>& false_block_generator = []() {}) {
233     TF_CHECK_OK(IfWithStatus(
234         name, condition,
235         [&]() {
236           true_block_generator();
237           return Status::OK();
238         },
239         [&]() {
240           false_block_generator();
241           return Status::OK();
242         }));
243   }
244 
245   using ArgumentVector = absl::Span<llvm::Value* const>;
246 
247   // Generates the following control flow structure:
248   //
249   //  define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) {
250   //    kernel_body_generator({arg0, arg1, ... arg`arguments.size()`});
251   //  }
252   //
253   //  ...
254   //  call @`kernel_name`(arguments[0], arguments[1] ...)
255   //  ...
256   //
257   // If a function called `kernel_name` is already present in the module then
258   // that function is re-used.  In that sense we're using the llvm::Module as a
259   // cache of outlined kernels, keyed by function name.
260   //
261   // If any of the values in `arguments` is nullptr (i.e. a nullptr
262   // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass
263   // in a nullptr llvm::Value* in its position to `kernel_body_generator`.
264   // Currently we only support at most one nullptr value in `arguments`.
265   static void EmitAndCallOutlinedKernel(
266       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
267       absl::string_view kernel_name, ArgumentVector arguments,
268       const std::function<void(ArgumentVector)>& kernel_body_generator);
269 
270   // Thin wrappers around the more general EmitAndCallOutlinedKernel above.
EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)271   static void EmitAndCallOutlinedKernel(
272       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
273       absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
274       llvm::Value* arg2,
275       const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
276           kernel_body_generator) {
277     EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2},
278                               [&](ArgumentVector args) {
279                                 kernel_body_generator(args[0], args[1],
280                                                       args[2]);
281                               });
282   }
283 
EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,llvm::Value * arg3,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)284   static void EmitAndCallOutlinedKernel(
285       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
286       absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
287       llvm::Value* arg2, llvm::Value* arg3,
288       const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
289                                llvm::Value*)>& kernel_body_generator) {
290     EmitAndCallOutlinedKernel(
291         module_config, b, kernel_name, {arg0, arg1, arg2, arg3},
292         [&](ArgumentVector args) {
293           kernel_body_generator(args[0], args[1], args[2], args[3]);
294         });
295   }
296 
297  private:
298   llvm::IRBuilder<>* b_;
299   llvm_ir::UnrollMode unroll_mode_;
300   bool prevent_vectorization_;
301 };
302 }  // namespace xla
303 
304 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
305