• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_RUNTIME_RT_PASSES_H_
17 #define XLA_MLIR_RUNTIME_RT_PASSES_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
26 #include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
27 
28 namespace xla {
29 namespace runtime {
30 
31 //===-----------------------------------------------------------------------===/
32 // Transformations targeting `rt` dialect.
33 //===-----------------------------------------------------------------------===/
34 
// Spelling of the `rt.entrypoint` attribute name.
// NOTE(review): presumably this attribute marks the function that
// `CreateConvertToEntrypoint` treats as the entrypoint -- confirm against the
// pass implementation.
static constexpr const char* kEntrypointAttrName = "rt.entrypoint";
36 
37 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
38 CreateConvertToEntrypoint();
39 
40 //===-----------------------------------------------------------------------===/
41 // Conversions targeting `rt` dialect.
42 //===-----------------------------------------------------------------------===/
43 
// Forward declarations are sufficient here: this header only names these types
// as reference parameters of the callbacks in `ConvertRuntimeToLLvmOpts`.
class CustomCallArgEncodingSet;
class CustomCallAttrEncodingSet;
class TypeIDNameRegistry;
47 
48 // Extension points for converting `rt` dialect to the LLVM dialect.
49 //
50 // Runtime custom calls is an extension mechanism for enabling compiled programs
51 // to call into the APIs provided by the user. It relies on converting
52 // values and attributes to the LLVM types (structs and pointers) with a
53 // well-defined memory layout, so that they can be passed across the function
54 // boundary and safely decoded (without dependency on C++ ABI).
55 //
56 // All user-defined types (values and attributes) that are passed to the custom
57 // calls must define the argument or attribute encoding.
58 struct ConvertRuntimeToLLvmOpts {
59   // Register names for the TypeIDs used for encoding types of custom arguments
60   // and attributes.
61   std::function<void(TypeIDNameRegistry&)> populate_type_id_names;
62 
63   // Add type conversions for user-defined types to the corresponding LLVM
64   // types. Conversion pass uses these extra conversions to convert arguments
65   // of the entrypoint function and values passed to the custom calls. Custom
66   // call argument encoding can further refine how values of LLVM types passed
67   // to the custom call handlers by passing custom encoding (see below).
68   std::function<void(mlir::TypeConverter&)> populate_type_conversions;
69 
70   // Add user-defined arguments encoding to the custom call lowering.
71   std::function<void(CustomCallArgEncodingSet&)> populate_arg_encodings;
72 
73   // Add user-defined attributes type encoding to the custom call lowering.
74   std::function<void(CustomCallAttrEncodingSet&)> populate_attr_encodings;
75 };
76 
77 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
78 CreateConvertRuntimeToLLVMPass(ConvertRuntimeToLLvmOpts opts = {});
79 
80 //===-----------------------------------------------------------------------===/
81 
82 #define GEN_PASS_REGISTRATION
83 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h.inc"
84 
85 }  // namespace runtime
86 }  // namespace xla
87 
88 #endif  // XLA_MLIR_RUNTIME_RT_PASSES_H_
89