/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
18 
#include <functional>
#include <string>

#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 
28 namespace xla {
29 // A thin wrapper around llvm_loop.h to make code generating structured control
30 // flow more readable.
31 class KernelSupportLibrary {
32  public:
33   // `b` is the llvm::IRBuilder instance used to generate LLVM IR.
34   // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop
35   // generated by this instance of KernelSupportLibrary.
36   explicit KernelSupportLibrary(
37       llvm::IRBuilder<>* b,
38       llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll,
39       bool prevent_vectorization = true)
b_(b)40       : b_(b),
41         unroll_mode_(unroll_mode),
42         prevent_vectorization_(prevent_vectorization) {}
43 
44   // Generates the following control flow structure:
45   //
46   //   if (`start` < `end`) {
47   //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
48   //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
49   //       `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`;
50   //   }
51   Status ForWithStatus(
52       absl::string_view name, llvm::Value* start, llvm::Value* end,
53       llvm::Value* step,
54       const std::function<Status(llvm::Value* ind_var,
55                                  bool is_first_iteration)>& for_body_generator);
56 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)57   void For(
58       absl::string_view name, llvm::Value* start, llvm::Value* end,
59       llvm::Value* step,
60       const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
61           for_body_generator) {
62     CHECK_EQ(Status::OK(),
63              ForWithStatus(
64                  name, start, end, step,
65                  [&](llvm::Value* ind_var, bool is_first_iteration) -> Status {
66                    for_body_generator(ind_var, is_first_iteration);
67                    return Status::OK();
68                  }));
69   }
70 
ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)71   Status ForWithStatus(
72       absl::string_view name, int64 start, int64 end, int64 step,
73       const std::function<Status(
74           llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) {
75     return ForWithStatus(name, /*start=*/b_->getInt64(start),
76                          /*end=*/b_->getInt64(end),
77                          /*step=*/b_->getInt64(step), for_body_generator);
78   }
79 
For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)80   void For(
81       absl::string_view name, int64 start, int64 end, int64 step,
82       const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
83           for_body_generator) {
84     For(name, /*start=*/b_->getInt64(start),
85         /*end=*/b_->getInt64(end),
86         /*step=*/b_->getInt64(step), for_body_generator);
87   }
88 
89   // Generates the following control flow structure if `peel_first_iteration` is
90   // true:
91   //
92   //   if (`start` < `end`) {
93   //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`;
94   //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
95   //       `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`;
96   //   }
97   //
98   // and the following if `peel_first_iteration` is false:
99   //
100   //   for (i64 i = `start`; i s< `end`; i += `step`)
101   //     `for_body_generator(/*ind_var=*/,i,
102   //                         /*is_first_iteration=*/,(i != `start`))`;
103   Status ForWithStatus(
104       absl::string_view name, llvm::Value* start, llvm::Value* end,
105       llvm::Value* step, bool peel_first_iteration,
106       const std::function<Status(llvm::Value* ind_var,
107                                  llvm::Value* is_first_iteration)>&
108           for_body_generator);
109 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)110   void For(absl::string_view name, llvm::Value* start, llvm::Value* end,
111            llvm::Value* step, bool peel_first_iteration,
112            const std::function<void(llvm::Value* ind_var,
113                                     llvm::Value* is_first_iteration)>&
114                for_body_generator) {
115     TF_CHECK_OK(ForWithStatus(
116         name, start, end, step, peel_first_iteration,
117         [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status {
118           for_body_generator(ind_var, is_first_iteration);
119           return Status::OK();
120         }));
121   }
122 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<Status (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)123   Status ForWithStatus(
124       absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
125       bool peel_first_iteration,
126       const std::function<Status(llvm::Value* ind_var,
127                                  llvm::Value* is_first_iteration)>&
128           for_body_generator) {
129     return ForWithStatus(
130         name, /*start=*/start, /*end=*/end,
131         /*step=*/llvm::ConstantInt::get(start->getType(), step),
132         peel_first_iteration, for_body_generator);
133   }
134 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)135   void For(absl::string_view name, llvm::Value* start, llvm::Value* end,
136            int64 step, bool peel_first_iteration,
137            const std::function<void(llvm::Value* ind_var,
138                                     llvm::Value* is_first_iteration)>&
139                for_body_generator) {
140     For(name, /*start=*/start, /*end=*/end,
141         /*step=*/llvm::ConstantInt::get(start->getType(), step),
142         peel_first_iteration, for_body_generator);
143   }
144 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)145   Status ForWithStatus(
146       absl::string_view name, llvm::Value* start, llvm::Value* end,
147       llvm::Value* step,
148       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
149     return ForWithStatus(name, start, end, step,
150                          /*peel_first_iteration=*/false,
151                          [&](llvm::Value* indvar, llvm::Value*) -> Status {
152                            return for_body_generator(indvar);
153                          });
154   }
155 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)156   void For(
157       absl::string_view name, llvm::Value* start, llvm::Value* end,
158       llvm::Value* step,
159       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
160     For(name, start, end, step,
161         /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) {
162           return for_body_generator(indvar);
163         });
164   }
165 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)166   Status ForWithStatus(
167       absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
168       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
169     return ForWithStatus(name, start, end,
170                          llvm::ConstantInt::get(start->getType(), step),
171                          /*peel_first_iteration=*/false,
172                          [&](llvm::Value* indvar, llvm::Value*) -> Status {
173                            return for_body_generator(indvar);
174                          });
175   }
176 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)177   void For(
178       absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
179       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
180     For(name, start, end, llvm::ConstantInt::get(start->getType(), step),
181         for_body_generator);
182   }
183 
ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)184   Status ForWithStatus(
185       absl::string_view name, int64 start, int64 end, int64 step,
186       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
187     return ForWithStatus(name, /*start=*/b_->getInt64(start),
188                          /*end=*/b_->getInt64(end),
189                          /*step=*/b_->getInt64(step), for_body_generator);
190   }
191 
For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)192   void For(
193       absl::string_view name, int64 start, int64 end, int64 step,
194       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
195     For(name, /*start=*/b_->getInt64(start),
196         /*end=*/b_->getInt64(end),
197         /*step=*/b_->getInt64(step), for_body_generator);
198   }
199 
200   // Generates the following control flow structure:
201   //
202   //   if (`condition`)
203   //     `true_block_generator()`;
204   //   else
205   //      `false_block_generator()`;
206   // The else is skipped if false_block_generator is null.
207   Status IfWithStatus(
208       absl::string_view name, llvm::Value* condition,
209       const std::function<Status()>& true_block_generator,
210       const std::function<Status()>& false_block_generator = nullptr);
211 
212   Status IfWithStatus(
213       llvm::Value* condition,
214       const std::function<Status()>& true_block_generator,
215       const std::function<Status()>& false_block_generator = []() -> Status {
216         return Status::OK();
217       }) {
218     return IfWithStatus("", condition, true_block_generator,
219                         false_block_generator);
220   }
221 
222   void If(llvm::Value* condition,
223           const std::function<void()>& true_block_generator,
224           const std::function<void()>& false_block_generator = nullptr) {
225     If("", condition, true_block_generator, false_block_generator);
226   }
227 
228   void If(absl::string_view name, llvm::Value* condition,
229           const std::function<void()>& true_block_generator,
230           const std::function<void()>& false_block_generator = nullptr) {
231     if (false_block_generator != nullptr) {
232       TF_CHECK_OK(IfWithStatus(
233           name, condition,
234           [&]() {
235             true_block_generator();
236             return Status::OK();
237           },
238           [&]() {
239             false_block_generator();
240             return Status::OK();
241           }));
242     } else {
243       TF_CHECK_OK(IfWithStatus(name, condition, [&]() {
244         true_block_generator();
245         return Status::OK();
246       }));
247     }
248   }
249 
250   using ArgumentVector = absl::Span<llvm::Value* const>;
251 
252   // Generates the following control flow structure:
253   //
254   //  define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) {
255   //    kernel_body_generator({arg0, arg1, ... arg`arguments.size()`});
256   //  }
257   //
258   //  ...
259   //  call @`kernel_name`(arguments[0], arguments[1] ...)
260   //  ...
261   //
262   // If a function called `kernel_name` is already present in the module then
263   // that function is re-used.  In that sense we're using the llvm::Module as a
264   // cache of outlined kernels, keyed by function name.
265   //
266   // If any of the values in `arguments` is nullptr (i.e. a nullptr
267   // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass
268   // in a nullptr llvm::Value* in its position to `kernel_body_generator`.
269   // Currently we only support at most one nullptr value in `arguments`.
270   static void EmitAndCallOutlinedKernel(
271       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
272       absl::string_view kernel_name, ArgumentVector arguments,
273       const std::function<void(ArgumentVector)>& kernel_body_generator);
274 
275   // Thin wrappers around the more general EmitAndCallOutlinedKernel above.
EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)276   static void EmitAndCallOutlinedKernel(
277       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
278       absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
279       llvm::Value* arg2,
280       const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
281           kernel_body_generator) {
282     EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2},
283                               [&](ArgumentVector args) {
284                                 kernel_body_generator(args[0], args[1],
285                                                       args[2]);
286                               });
287   }
288 
EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,llvm::Value * arg3,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)289   static void EmitAndCallOutlinedKernel(
290       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
291       absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
292       llvm::Value* arg2, llvm::Value* arg3,
293       const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
294                                llvm::Value*)>& kernel_body_generator) {
295     EmitAndCallOutlinedKernel(
296         module_config, b, kernel_name, {arg0, arg1, arg2, arg3},
297         [&](ArgumentVector args) {
298           kernel_body_generator(args[0], args[1], args[2], args[3]);
299         });
300   }
301 
302  private:
303   llvm::IRBuilder<>* b_;
304   llvm_ir::UnrollMode unroll_mode_;
305   bool prevent_vectorization_;
306 };
307 }  // namespace xla
308 
309 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
310