• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_TRANSFORMS_RUNTIME_COMPILATION_PIPELINE_H_
17 #define XLA_MLIR_TRANSFORMS_RUNTIME_COMPILATION_PIPELINE_H_
18 
19 #include <functional>
20 
21 #include "mlir/Pass/PassManager.h"  // from @llvm-project
22 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
23 #include "tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.h"
24 #include "tensorflow/compiler/xla/runtime/type_id.h"
25 
26 namespace xla {
27 namespace runtime {
28 
// User-supplied customization hooks for the default XLA runtime compilation
// pipeline (see CreateDefaultXlaRuntimeCompilationPipeline, which takes these
// options by const reference). Every member is a std::function callback that
// receives a mutable reference to the object it is expected to populate.
//
// NOTE(review): a default-constructed std::function is empty; whether the
// pipeline tolerates unset callbacks is not visible from this header — confirm
// against the implementation before leaving a member unset.
29 struct CompilationPipelineOptions {
30   // Register names for the TypeIDs used for encoding types of custom arguments
31   // and attributes.
32   std::function<void(TypeIDNameRegistry&)> populate_type_id_names;
33 
34   // Add type conversions from user-defined types to LLVM types. These
35   // conversions are required for lowering runtime operations to the
36   // corresponding runtime APIs (including custom calls).
37   std::function<void(mlir::TypeConverter&)> populate_type_conversions;
38 
39   // Add user-defined encoding for JitRt custom call arguments and attributes.
40   //
41   // Custom encodings allow passing dialect-specific attributes (enums and
42   // structs) to the custom calls, and decoding them into dialect-specific
43   // runtime values in the handlers (see custom_call_to_llvm.h for details).
   //
   // `populate_arg_encodings` receives the argument encoding set and
   // `populate_attr_encodings` receives the attribute encoding set.
44   std::function<void(CustomCallArgEncodingSet&)> populate_arg_encodings;
45   std::function<void(CustomCallAttrEncodingSet&)> populate_attr_encodings;
46 };
47 
48 // Registers the dialects, interfaces and dialect translations required by
49 // the default XLA runtime compilation pipeline with the given `registry`.
50 void RegisterDefaultXlaRuntimeDialects(mlir::DialectRegistry& registry);
51 
52 // Creates default XLA runtime compilation pipeline that lowers from the `rt`
53 // and `memref` dialects to the LLVMIR dialect. This is a very simple pipeline
54 // that is mostly intended for writing tests for the XLA runtime, and it is
55 // expected that all end users will construct their own compilation pipelines
56 // from the available XLA and MLIR passes.
//
// `pm` is the pass manager the pipeline is built on (presumably passes are
// appended to it — confirm in the implementation). `opts` carries the
// user-defined callbacks declared in CompilationPipelineOptions.
57 void CreateDefaultXlaRuntimeCompilationPipeline(
58     mlir::OpPassManager& pm, const CompilationPipelineOptions& opts);
59 
60 }  // namespace runtime
61 }  // namespace xla
62 
63 #endif  // XLA_MLIR_TRANSFORMS_RUNTIME_COMPILATION_PIPELINE_H_
64