1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
18
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
33
34 namespace tensorflow {
35
// Populates the supplied pass manager `pm` with the passes required to run the
// TF MLIR to XLA HLO MLIR conversion/legalization.
//
// device_type: XLA JIT device to use for compilation, such as "XLA_CPU_JIT",
//   "XLA_GPU_JIT" or "XLA_TPU_JIT".
// prefer_tf2xla: when true, prefer tf2xla fallback kernels over MLIR native
//   kernels for legalization to HLO.
// custom_legalization_passes: backend-specific legalization passes to add to
//   the pipeline. Per the `ConvertMLIRToXlaComputation` documentation in this
//   header, these run before the default TF legalization passes.
//   NOTE(review): the span is mutable and holds `std::unique_ptr`s. This
//   suggests the pipeline takes ownership by moving the passes out, which
//   would leave null entries behind in the caller's storage. Confirm against
//   the implementation.
void CreateConvertMlirToXlaHloPipeline(
mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes);
43
// Lowers an MLIR module to XLA HLO inside an XlaComputation.
//
// The input module should only contain operations in the tf dialect. If it
// contains operations in another dialect, for example tf_executor, an error is
// returned. The exception is tf_executor dialect ops that are optimized away
// through canonicalization.
//
// Operations in the tf dialect are lowered to XLA HLO through the following
// steps:
// . Legalizes control flow operations.
// . Decomposes compound resource operations so that the only remaining
//   operations on resource variables are resource reads/writes.
// . Replaces resource reads/writes with function inputs/outputs and
//   eliminates the use of resource variables.
// . Legalizes the operations to XLA HLO operations.
// . Canonicalizes the XLA HLO operations.
//
// module_op: the tf-dialect module to lower.
// device_type: XLA JIT device to use for compilation, such as "XLA_CPU_JIT",
//   "XLA_GPU_JIT" or "XLA_TPU_JIT".
// xla_computation: output; receives the lowered XLA computation.
// use_tuple_args: when true, always create a tuple argument for the entry
//   computation.
// prefer_tf2xla: when true, prefer tf2xla fallback kernels over MLIR native
//   kernels for legalization to HLO.
// return_tuple: when true, always create a tuple result for the entry
//   computation.
// shape_representation_fn: when set, this shape representation function is
//   used to determine argument and result shapes. Otherwise the original
//   shapes are used as is.
// custom_legalization_passes: passes for backend-specific ops, run before the
//   default TF legalization passes.
//
// TODO(hinsu): Migrate options to a separate struct.
Status ConvertMLIRToXlaComputation(
mlir::ModuleOp module_op, llvm::StringRef device_type,
xla::XlaComputation* xla_computation, bool use_tuple_args,
bool prefer_tf2xla, bool return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes = {});
81
82 // Helper struct representing argument tensor or resource handle shapes.
83 struct TensorOrResourceShape {
84 TensorShape shape;
85 bool is_resource = false;
86 };
87
// Refines the MLIR types in `module` using the shape information supplied in
// `arg_shapes`.
// NOTE(review): presumably `arg_shapes` is matched positionally against the
// arguments of the module's entry function. Confirm against the
// implementation.
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
mlir::ModuleOp module);
91
// Lowers the TF module `module_op` to MHLO and inserts the resulting HLO into
// `builder`.
//
// xla_params: HLO-level inputs to `module_op` that have already been added to
//   `builder`.
// returns: output; receives the XlaOps returned by the lowered module.
// arg_shapes: shape (and resource) information for the module's arguments.
// device_type: XLA JIT device to use for compilation, e.g. "XLA_CPU_JIT".
// custom_legalization_passes: backend-specific legalization passes to include
//   in the lowering.
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes);
102
// Applies shape, description, and resource information to the inputs and
// outputs in `compilation_result`. This should be called after
// `compilation_result->computation` has been set.
//
// NOTE(review): `use_tuple_args` and `shape_representation_fn` presumably have
// the same meaning as the identically named parameters of
// `ConvertMLIRToXlaComputation`. The effect of
// `use_resource_updates_for_aliases` is not visible from this header. Confirm
// both against the implementation.
Status PopulateResultIOInfo(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
bool use_tuple_args, bool use_resource_updates_for_aliases,
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result);
111
// Compiles an MLIR module into XLA HLO, generates all accompanying metadata and
// stores them in `compilation_result`.
//
// If `analyse_graph` is true, the graph is legalized only if graph analysis
// succeeds for it; if the analysis fails, an error is returned.
//
// custom_legalization_passes: backend-specific legalization passes to include
//   in the lowering.
Status CompileMlirToXlaHlo(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
bool use_return_tuple, bool use_resource_updates_for_aliases,
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes);
125
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in `compilation_result`.
//
// NOTE(review): `mlir_module_string` is presumably MLIR in textual assembly
// form. Unlike `CompileMlirToXlaHlo`, `arg_shapes` here carries plain
// TensorShapes with no resource flag. Confirm how resource arguments are
// handled against the implementation.
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes = {});
135
// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
// and stores them in `compilation_result`. The graph must already be converted
// to MLIR, with the tf_executor dialect still present after import.
//
// This rewrites the arguments and runs the TensorFlow standard pipeline before
// invoking `CompileMlirToXlaHlo`.
Status CompileGraphToXlaHlo(
mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
bool use_return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes);
149
// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
// and stores them in `compilation_result`.
//
// Unlike the mlir::ModuleOp overload, this overload takes no
// `use_return_tuple` parameter.
// NOTE(review): `control_rets` presumably names the graph's control-return
// nodes, and `flib_def`/`debug_info` are presumably used when importing
// `graph` into MLIR. Confirm against the implementation.
Status CompileGraphToXlaHlo(
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
bool use_tuple_args, bool analyse_graph,
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes = {});
161
// Compiles a Graph from TF to HLO and adds the resulting HLO to `builder`.
//
// This function contributes HLO to a larger HLO computation, so it consumes
// HLO-level inputs and produces HLO-level outputs:
// xla_params: HLO-level inputs, already added to `builder`.
// returns: output; receives the HLO-level outputs.
// NOTE(review): `control_rets`, `flib_def` and `debug_info` presumably play the
// same role as in the Graph overload of `CompileGraphToXlaHlo`. Confirm against
// the implementation.
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::ArrayRef<XlaArgument> args,
llvm::ArrayRef<std::string> control_rets,
llvm::StringRef device_type,
const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes = {});
176
CompileToHloGraphAnalysisFailedError()177 static inline Status CompileToHloGraphAnalysisFailedError() {
178 return errors::Internal("disabled after graph analysis");
179 }
180 } // namespace tensorflow
181
182 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
183