1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_MLIR_RUNTIME_RT_PASSES_H_ 17 #define XLA_MLIR_RUNTIME_RT_PASSES_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project 23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 24 #include "mlir/Pass/Pass.h" // from @llvm-project 25 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project 26 #include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h" 27 28 namespace xla { 29 namespace runtime { 30 31 //===-----------------------------------------------------------------------===/ 32 // Transformations targeting `rt` dialect. 33 //===-----------------------------------------------------------------------===/ 34 35 static constexpr char const* kEntrypointAttrName = "rt.entrypoint"; 36 37 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 38 CreateConvertToEntrypoint(); 39 40 //===-----------------------------------------------------------------------===/ 41 // Conversions targeting `rt` dialect. 42 //===-----------------------------------------------------------------------===/ 43 44 class TypeIDNameRegistry; 45 class CustomCallArgEncodingSet; 46 class CustomCallAttrEncodingSet; 47 48 // Extension points for converting `rt` dialect to the LLVM dialect. 49 // 50 // Runtime custom calls is an extension mechanism for enabling compiled programs 51 // to call into the APIs provided by the user. It relies on converting 52 // values and attributes to the LLVM types (structs and pointers) with a 53 // well-defined memory layout, so that they can be passed across the function 54 // boundary and safely decoded (without dependency on C++ ABI). 55 // 56 // All user-defined types (values and attributes) that are passed to the custom 57 // calls must define the argument or attribute encoding. 58 struct ConvertRuntimeToLLvmOpts { 59 // Register names for the TypeIDs used for encoding types of custom arguments 60 // and attributes. 61 std::function<void(TypeIDNameRegistry&)> populate_type_id_names; 62 63 // Add type conversions for user-defined types to the corresponding LLVM 64 // types. Conversion pass uses these extra conversions to convert arguments 65 // of the entrypoint function and values passed to the custom calls. Custom 66 // call argument encoding can further refine how values of LLVM types passed 67 // to the custom call handlers by passing custom encoding (see below). 68 std::function<void(mlir::TypeConverter&)> populate_type_conversions; 69 70 // Add user-defined arguments encoding to the custom call lowering. 71 std::function<void(CustomCallArgEncodingSet&)> populate_arg_encodings; 72 73 // Add user-defined attributes type encoding to the custom call lowering. 74 std::function<void(CustomCallAttrEncodingSet&)> populate_attr_encodings; 75 }; 76 77 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 78 CreateConvertRuntimeToLLVMPass(ConvertRuntimeToLLvmOpts opts = {}); 79 80 //===-----------------------------------------------------------------------===/ 81 82 #define GEN_PASS_REGISTRATION 83 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h.inc" 84 85 } // namespace runtime 86 } // namespace xla 87 88 #endif // XLA_MLIR_RUNTIME_RT_PASSES_H_ 89