1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_MLIR_TRANSFORMS_RUNTIME_COMPILATION_PIPELINE_H_ 17 #define XLA_MLIR_TRANSFORMS_RUNTIME_COMPILATION_PIPELINE_H_ 18 19 #include <functional> 20 21 #include "mlir/Pass/PassManager.h" // from @llvm-project 22 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project 23 #include "tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.h" 24 #include "tensorflow/compiler/xla/runtime/type_id.h" 25 26 namespace xla { 27 namespace runtime { 28 29 struct CompilationPipelineOptions { 30 // Register names for the TypeIDs used for encoding types of custom arguments 31 // and attributes. 32 std::function<void(TypeIDNameRegistry&)> populate_type_id_names; 33 34 // Add type conversions from user-defined types to LLVM types. These 35 // conversions are required for lowering runtime operations to the 36 // corresponding runtime APIs (including custom calls). 37 std::function<void(mlir::TypeConverter&)> populate_type_conversions; 38 39 // Add user-defined encoding for JitRt custom call arguments and attributes. 40 // 41 // Custom encodings allow to pass dialect-specific attributes (enums and 42 // structs) to the custom calls, and decode them into dialect-specific runtime 43 // values in the custom call handlers (see custom_call_to_llvm.h for details). 44 std::function<void(CustomCallArgEncodingSet&)> populate_arg_encodings; 45 std::function<void(CustomCallAttrEncodingSet&)> populate_attr_encodings; 46 }; 47 48 // Registers dialects, interfaces and dialects translations with the registry 49 // required by the default XLA runtime compilation pipeline. 50 void RegisterDefaultXlaRuntimeDialects(mlir::DialectRegistry& registry); 51 52 // Creates default XLA runtime compilation pipeline that lowers from the `rt` 53 // and `memref` dialects to the LLVMIR dialect. This is a very simple pipeline 54 // that is mostly intended for writing tests for the XLA runtime, and it is 55 // expected that all end users will construct their own compilation pipelines 56 // from the available XLA and MLIR passes. 57 void CreateDefaultXlaRuntimeCompilationPipeline( 58 mlir::OpPassManager& pm, const CompilationPipelineOptions& opts); 59 60 } // namespace runtime 61 } // namespace xla 62 63 #endif // XLA_MLIR_TRANSFORMS_RUNTIME_COMPILATION_PIPELINE_H_ 64