1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 18 19 #include <string> 20 21 #include "absl/strings/string_view.h" 22 #include "llvm/IR/BasicBlock.h" 23 #include "llvm/IR/IRBuilder.h" 24 #include "llvm/IR/Value.h" 25 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" 26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 27 28 namespace xla { 29 // A thin wrapper around llvm_loop.h to make code generating structured control 30 // flow more readable. 31 class KernelSupportLibrary { 32 public: 33 // `b` is the llvm::IRBuilder instance used to generate LLVM IR. 34 // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop 35 // generated by this instance of KernelSupportLibrary. 36 explicit KernelSupportLibrary( 37 llvm::IRBuilder<>* b, 38 llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, 39 bool prevent_vectorization = true) b_(b)40 : b_(b), 41 unroll_mode_(unroll_mode), 42 prevent_vectorization_(prevent_vectorization) {} 43 44 // Generates the following control flow structure: 45 // 46 // if (`start` < `end`) { 47 // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`; 48 // for (i64 i = `start` + `step`; i s< `end`; i += `step`) 49 // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; 50 // } 51 Status ForWithStatus( 52 absl::string_view name, llvm::Value* start, llvm::Value* end, 53 llvm::Value* step, 54 const std::function<Status(llvm::Value* ind_var, 55 bool is_first_iteration)>& for_body_generator); 56 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)57 void For( 58 absl::string_view name, llvm::Value* start, llvm::Value* end, 59 llvm::Value* step, 60 const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& 61 for_body_generator) { 62 CHECK_EQ(Status::OK(), 63 ForWithStatus( 64 name, start, end, step, 65 [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { 66 for_body_generator(ind_var, is_first_iteration); 67 return Status::OK(); 68 })); 69 } 70 ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)71 Status ForWithStatus( 72 absl::string_view name, int64 start, int64 end, int64 step, 73 const std::function<Status( 74 llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { 75 return ForWithStatus(name, /*start=*/b_->getInt64(start), 76 /*end=*/b_->getInt64(end), 77 /*step=*/b_->getInt64(step), for_body_generator); 78 } 79 For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)80 void For( 81 absl::string_view name, int64 start, int64 end, int64 step, 82 const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& 83 for_body_generator) { 84 For(name, /*start=*/b_->getInt64(start), 85 /*end=*/b_->getInt64(end), 86 /*step=*/b_->getInt64(step), for_body_generator); 87 } 88 89 // Generates the following control flow structure if `peel_first_iteration` is 90 // true: 91 // 92 // if (`start` < `end`) { 93 // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`; 94 // for (i64 i = `start` + `step`; i s< `end`; i += `step`) 95 // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`; 96 // } 97 // 98 // and the following if `peel_first_iteration` is false: 99 // 100 // for (i64 i = `start`; i s< `end`; i += `step`) 101 // `for_body_generator(/*ind_var=*/,i, 102 // /*is_first_iteration=*/,(i != `start`))`; 103 Status ForWithStatus( 104 absl::string_view name, llvm::Value* start, llvm::Value* end, 105 llvm::Value* step, bool peel_first_iteration, 106 const std::function<Status(llvm::Value* ind_var, 107 llvm::Value* is_first_iteration)>& 108 for_body_generator); 109 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)110 void For(absl::string_view name, llvm::Value* start, llvm::Value* end, 111 llvm::Value* step, bool peel_first_iteration, 112 const std::function<void(llvm::Value* ind_var, 113 llvm::Value* is_first_iteration)>& 114 for_body_generator) { 115 TF_CHECK_OK(ForWithStatus( 116 name, start, end, step, peel_first_iteration, 117 [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { 118 for_body_generator(ind_var, is_first_iteration); 119 return Status::OK(); 120 })); 121 } 122 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<Status (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)123 Status ForWithStatus( 124 absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, 125 bool peel_first_iteration, 126 const std::function<Status(llvm::Value* ind_var, 127 llvm::Value* is_first_iteration)>& 128 for_body_generator) { 129 return ForWithStatus( 130 name, /*start=*/start, /*end=*/end, 131 /*step=*/llvm::ConstantInt::get(start->getType(), step), 132 peel_first_iteration, for_body_generator); 133 } 134 For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)135 void For(absl::string_view name, llvm::Value* start, llvm::Value* end, 136 int64 step, bool peel_first_iteration, 137 const std::function<void(llvm::Value* ind_var, 138 llvm::Value* is_first_iteration)>& 139 for_body_generator) { 140 For(name, /*start=*/start, /*end=*/end, 141 /*step=*/llvm::ConstantInt::get(start->getType(), step), 142 peel_first_iteration, for_body_generator); 143 } 144 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)145 Status ForWithStatus( 146 absl::string_view name, llvm::Value* start, llvm::Value* end, 147 llvm::Value* step, 148 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 149 return ForWithStatus(name, start, end, step, 150 /*peel_first_iteration=*/false, 151 [&](llvm::Value* indvar, llvm::Value*) -> Status { 152 return for_body_generator(indvar); 153 }); 154 } 155 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)156 void For( 157 absl::string_view name, llvm::Value* start, llvm::Value* end, 158 llvm::Value* step, 159 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 160 For(name, start, end, step, 161 /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { 162 return for_body_generator(indvar); 163 }); 164 } 165 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)166 Status ForWithStatus( 167 absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, 168 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 169 return ForWithStatus(name, start, end, 170 llvm::ConstantInt::get(start->getType(), step), 171 /*peel_first_iteration=*/false, 172 [&](llvm::Value* indvar, llvm::Value*) -> Status { 173 return for_body_generator(indvar); 174 }); 175 } 176 For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)177 void For( 178 absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, 179 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 180 For(name, start, end, llvm::ConstantInt::get(start->getType(), step), 181 for_body_generator); 182 } 183 ForWithStatus(absl::string_view name,int64 start,int64 end,int64 step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)184 Status ForWithStatus( 185 absl::string_view name, int64 start, int64 end, int64 step, 186 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 187 return ForWithStatus(name, /*start=*/b_->getInt64(start), 188 /*end=*/b_->getInt64(end), 189 /*step=*/b_->getInt64(step), for_body_generator); 190 } 191 For(absl::string_view name,int64 start,int64 end,int64 step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)192 void For( 193 absl::string_view name, int64 start, int64 end, int64 step, 194 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 195 For(name, /*start=*/b_->getInt64(start), 196 /*end=*/b_->getInt64(end), 197 /*step=*/b_->getInt64(step), for_body_generator); 198 } 199 200 // Generates the following control flow structure: 201 // 202 // if (`condition`) 203 // `true_block_generator()`; 204 // else 205 // `false_block_generator()`; 206 // The else is skipped if false_block_generator is null. 207 Status IfWithStatus( 208 absl::string_view name, llvm::Value* condition, 209 const std::function<Status()>& true_block_generator, 210 const std::function<Status()>& false_block_generator = nullptr); 211 212 Status IfWithStatus( 213 llvm::Value* condition, 214 const std::function<Status()>& true_block_generator, 215 const std::function<Status()>& false_block_generator = []() -> Status { 216 return Status::OK(); 217 }) { 218 return IfWithStatus("", condition, true_block_generator, 219 false_block_generator); 220 } 221 222 void If(llvm::Value* condition, 223 const std::function<void()>& true_block_generator, 224 const std::function<void()>& false_block_generator = nullptr) { 225 If("", condition, true_block_generator, false_block_generator); 226 } 227 228 void If(absl::string_view name, llvm::Value* condition, 229 const std::function<void()>& true_block_generator, 230 const std::function<void()>& false_block_generator = nullptr) { 231 if (false_block_generator != nullptr) { 232 TF_CHECK_OK(IfWithStatus( 233 name, condition, 234 [&]() { 235 true_block_generator(); 236 return Status::OK(); 237 }, 238 [&]() { 239 false_block_generator(); 240 return Status::OK(); 241 })); 242 } else { 243 TF_CHECK_OK(IfWithStatus(name, condition, [&]() { 244 true_block_generator(); 245 return Status::OK(); 246 })); 247 } 248 } 249 250 using ArgumentVector = absl::Span<llvm::Value* const>; 251 252 // Generates the following control flow structure: 253 // 254 // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { 255 // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); 256 // } 257 // 258 // ... 259 // call @`kernel_name`(arguments[0], arguments[1] ...) 260 // ... 261 // 262 // If a function called `kernel_name` is already present in the module then 263 // that function is re-used. In that sense we're using the llvm::Module as a 264 // cache of outlined kernels, keyed by function name. 265 // 266 // If any of the values in `arguments` is nullptr (i.e. a nullptr 267 // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass 268 // in a nullptr llvm::Value* in its position to `kernel_body_generator`. 269 // Currently we only support at most one nullptr value in `arguments`. 270 static void EmitAndCallOutlinedKernel( 271 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 272 absl::string_view kernel_name, ArgumentVector arguments, 273 const std::function<void(ArgumentVector)>& kernel_body_generator); 274 275 // Thin wrappers around the more general EmitAndCallOutlinedKernel above. EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)276 static void EmitAndCallOutlinedKernel( 277 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 278 absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, 279 llvm::Value* arg2, 280 const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>& 281 kernel_body_generator) { 282 EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2}, 283 [&](ArgumentVector args) { 284 kernel_body_generator(args[0], args[1], 285 args[2]); 286 }); 287 } 288 EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,llvm::Value * arg3,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)289 static void EmitAndCallOutlinedKernel( 290 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 291 absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, 292 llvm::Value* arg2, llvm::Value* arg3, 293 const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*, 294 llvm::Value*)>& kernel_body_generator) { 295 EmitAndCallOutlinedKernel( 296 module_config, b, kernel_name, {arg0, arg1, arg2, arg3}, 297 [&](ArgumentVector args) { 298 kernel_body_generator(args[0], args[1], args[2], args[3]); 299 }); 300 } 301 302 private: 303 llvm::IRBuilder<>* b_; 304 llvm_ir::UnrollMode unroll_mode_; 305 bool prevent_vectorization_; 306 }; 307 } // namespace xla 308 309 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 310