1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 18 19 #include <string> 20 21 #include "absl/strings/string_view.h" 22 #include "llvm/IR/BasicBlock.h" 23 #include "llvm/IR/IRBuilder.h" 24 #include "llvm/IR/Value.h" 25 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" 26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 27 28 namespace xla { 29 // A thin wrapper around llvm_loop.h to make code generating structured control 30 // flow more readable. 31 class KernelSupportLibrary { 32 public: 33 // `b` is the llvm::IRBuilder instance used to generate LLVM IR. 34 // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop 35 // generated by this instance of KernelSupportLibrary. 36 explicit KernelSupportLibrary( 37 llvm::IRBuilder<>* b, 38 llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, 39 bool prevent_vectorization = true) b_(b)40 : b_(b), 41 unroll_mode_(unroll_mode), 42 prevent_vectorization_(prevent_vectorization) {} 43 44 // Generates the following control flow structure: 45 // 46 // if (`start` < `end`) { 47 // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`; 48 // for (i64 i = `start` + `step`; i s< `end`; i += `step`) 49 // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; 50 // } 51 Status ForWithStatus( 52 absl::string_view name, llvm::Value* start, llvm::Value* end, 53 llvm::Value* step, 54 const std::function<Status(llvm::Value* ind_var, 55 bool is_first_iteration)>& for_body_generator); 56 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)57 void For( 58 absl::string_view name, llvm::Value* start, llvm::Value* end, 59 llvm::Value* step, 60 const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& 61 for_body_generator) { 62 CHECK_EQ(Status::OK(), 63 ForWithStatus( 64 name, start, end, step, 65 [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { 66 for_body_generator(ind_var, is_first_iteration); 67 return Status::OK(); 68 })); 69 } 70 ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)71 Status ForWithStatus( 72 absl::string_view name, int64 start, int64 end, int64 step, 73 const std::function<Status( 74 llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { 75 return ForWithStatus(name, /*start=*/b_->getInt64(start), 76 /*end=*/b_->getInt64(end), 77 /*step=*/b_->getInt64(step), for_body_generator); 78 } 79 For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)80 void For( 81 absl::string_view name, int64 start, int64 end, int64 step, 82 const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& 83 for_body_generator) { 84 For(name, /*start=*/b_->getInt64(start), 85 /*end=*/b_->getInt64(end), 86 /*step=*/b_->getInt64(step), for_body_generator); 87 } 88 89 // Generates the following control flow structure if `peel_first_iteration` is 90 // true: 91 // 92 // if (`start` < `end`) { 93 // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`; 94 // for (i64 i = `start` + `step`; i s< `end`; i += `step`) 95 // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`; 96 // } 97 // 98 // and the following if `peel_first_iteration` is false: 99 // 100 // for (i64 i = `start`; i s< `end`; i += `step`) 101 // `for_body_generator(/*ind_var=*/,i, 102 // /*is_first_iteration=*/,(i != `start`))`; 103 Status ForWithStatus( 104 absl::string_view name, llvm::Value* start, llvm::Value* end, 105 llvm::Value* step, bool peel_first_iteration, 106 const std::function<Status(llvm::Value* ind_var, 107 llvm::Value* is_first_iteration)>& 108 for_body_generator); 109 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)110 void For(absl::string_view name, llvm::Value* start, llvm::Value* end, 111 llvm::Value* step, bool peel_first_iteration, 112 const std::function<void(llvm::Value* ind_var, 113 llvm::Value* is_first_iteration)>& 114 for_body_generator) { 115 TF_CHECK_OK(ForWithStatus( 116 name, start, end, step, peel_first_iteration, 117 [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { 118 for_body_generator(ind_var, is_first_iteration); 119 return Status::OK(); 120 })); 121 } 122 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<Status (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)123 Status ForWithStatus( 124 absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, 125 bool peel_first_iteration, 126 const std::function<Status(llvm::Value* ind_var, 127 llvm::Value* is_first_iteration)>& 128 for_body_generator) { 129 return ForWithStatus( 130 name, /*start=*/start, /*end=*/end, 131 /*step=*/llvm::ConstantInt::get(start->getType(), step), 132 peel_first_iteration, for_body_generator); 133 } 134 For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)135 void For(absl::string_view name, llvm::Value* start, llvm::Value* end, 136 int64 step, bool peel_first_iteration, 137 const std::function<void(llvm::Value* ind_var, 138 llvm::Value* is_first_iteration)>& 139 for_body_generator) { 140 For(name, /*start=*/start, /*end=*/end, 141 /*step=*/llvm::ConstantInt::get(start->getType(), step), 142 peel_first_iteration, for_body_generator); 143 } 144 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)145 Status ForWithStatus( 146 absl::string_view name, llvm::Value* start, llvm::Value* end, 147 llvm::Value* step, 148 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 149 return ForWithStatus(name, start, end, step, 150 /*peel_first_iteration=*/false, 151 [&](llvm::Value* indvar, llvm::Value*) -> Status { 152 return for_body_generator(indvar); 153 }); 154 } 155 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)156 void For( 157 absl::string_view name, llvm::Value* start, llvm::Value* end, 158 llvm::Value* step, 159 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 160 For(name, start, end, step, 161 /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { 162 return for_body_generator(indvar); 163 }); 164 } 165 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)166 Status ForWithStatus( 167 absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, 168 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 169 return ForWithStatus(name, start, end, 170 llvm::ConstantInt::get(start->getType(), step), 171 /*peel_first_iteration=*/false, 172 [&](llvm::Value* indvar, llvm::Value*) -> Status { 173 return for_body_generator(indvar); 174 }); 175 } 176 For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)177 void For( 178 absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, 179 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 180 For(name, start, end, llvm::ConstantInt::get(start->getType(), step), 181 for_body_generator); 182 } 183 ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)184 Status ForWithStatus( 185 absl::string_view name, int64 start, int64 end, int64 step, 186 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 187 return ForWithStatus(name, /*start=*/b_->getInt64(start), 188 /*end=*/b_->getInt64(end), 189 /*step=*/b_->getInt64(step), for_body_generator); 190 } 191 For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)192 void For( 193 absl::string_view name, int64 start, int64 end, int64 step, 194 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 195 For(name, /*start=*/b_->getInt64(start), 196 /*end=*/b_->getInt64(end), 197 /*step=*/b_->getInt64(step), for_body_generator); 198 } 199 200 // Generates the following control flow structure: 201 // 202 // if (`condition`) 203 // `true_block_generator()`; 204 // else 205 // `false_block_generator()`; 206 Status IfWithStatus( 207 absl::string_view name, llvm::Value* condition, 208 const std::function<Status()>& true_block_generator, 209 const std::function<Status()>& false_block_generator = []() -> Status { 210 return Status::OK(); 211 }); 212 213 Status IfWithStatus( 214 llvm::Value* condition, 215 const std::function<Status()>& true_block_generator, 216 const std::function<Status()>& false_block_generator = []() -> Status { 217 return Status::OK(); 218 }) { 219 return IfWithStatus("", condition, true_block_generator, 220 false_block_generator); 221 } 222 223 void If( 224 llvm::Value* condition, const std::function<void()>& true_block_generator, 225 const std::function<void()>& false_block_generator = []() {}) { 226 If("", condition, true_block_generator, false_block_generator); 227 } 228 229 void If( 230 absl::string_view name, llvm::Value* condition, 231 const std::function<void()>& true_block_generator, 232 const std::function<void()>& false_block_generator = []() {}) { 233 TF_CHECK_OK(IfWithStatus( 234 name, condition, 235 [&]() { 236 true_block_generator(); 237 return Status::OK(); 238 }, 239 [&]() { 240 false_block_generator(); 241 return Status::OK(); 242 })); 243 } 244 245 using ArgumentVector = absl::Span<llvm::Value* const>; 246 247 // Generates the following control flow structure: 248 // 249 // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { 250 // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); 251 // } 252 // 253 // ... 254 // call @`kernel_name`(arguments[0], arguments[1] ...) 255 // ... 256 // 257 // If a function called `kernel_name` is already present in the module then 258 // that function is re-used. In that sense we're using the llvm::Module as a 259 // cache of outlined kernels, keyed by function name. 260 // 261 // If any of the values in `arguments` is nullptr (i.e. a nullptr 262 // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass 263 // in a nullptr llvm::Value* in its position to `kernel_body_generator`. 264 // Currently we only support at most one nullptr value in `arguments`. 265 static void EmitAndCallOutlinedKernel( 266 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 267 absl::string_view kernel_name, ArgumentVector arguments, 268 const std::function<void(ArgumentVector)>& kernel_body_generator); 269 270 // Thin wrappers around the more general EmitAndCallOutlinedKernel above. EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)271 static void EmitAndCallOutlinedKernel( 272 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 273 absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, 274 llvm::Value* arg2, 275 const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>& 276 kernel_body_generator) { 277 EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2}, 278 [&](ArgumentVector args) { 279 kernel_body_generator(args[0], args[1], 280 args[2]); 281 }); 282 } 283 EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,llvm::Value * arg3,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)284 static void EmitAndCallOutlinedKernel( 285 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 286 absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, 287 llvm::Value* arg2, llvm::Value* arg3, 288 const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*, 289 llvm::Value*)>& kernel_body_generator) { 290 EmitAndCallOutlinedKernel( 291 module_config, b, kernel_name, {arg0, arg1, arg2, arg3}, 292 [&](ArgumentVector args) { 293 kernel_body_generator(args[0], args[1], args[2], args[3]); 294 }); 295 } 296 297 private: 298 llvm::IRBuilder<>* b_; 299 llvm_ir::UnrollMode unroll_mode_; 300 bool prevent_vectorization_; 301 }; 302 } // namespace xla 303 304 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 305