/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_

#include <memory>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

// Populates the supplied passmanager with the passes required to run the
// TF MLIR to XLA HLO MLIR conversion/legalization. Custom legalization passes
// can be populated in `custom_legalization_passes`.
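//
// A minimal usage sketch (not part of this header: `context` and `module_op`
// below are assumed to already exist):
//
//   mlir::PassManager pm(&context);
//   CreateConvertMlirToXlaHloPipeline(pm, /*device_type=*/"XLA_CPU_JIT",
//                                     /*prefer_tf2xla=*/false,
//                                     /*custom_legalization_passes=*/{});
//   if (mlir::failed(pm.run(module_op))) {
//     // Handle pipeline failure.
//   }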
void CreateConvertMlirToXlaHloPipeline(
    mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes);

// Lowers an MLIR module to XLA HLO inside an XlaComputation. The input module
// should only contain operations in the tf dialect; if it contains operations
// outside the tf dialect, for example in the tf_executor dialect, an error is
// returned. The exception is tf_executor dialect ops that are optimized away
// through canonicalization.
//
// Operations in tf dialect are lowered to XLA HLO through the following steps:
//   . Legalizes control flow operations.
//   . Decomposes compound resource operations so that the only remaining
//     operations on resource variables are resource reads/writes.
//   . Replaces resource reads/writes with function inputs/outputs and
//     eliminates the use of resource variables.
//   . Legalizes the operations to XLA HLO operations.
//   . Canonicalizes the XLA HLO operations.
//
// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT",
//   "XLA_GPU_JIT" or "XLA_TPU_JIT".
// use_tuple_args: when this is true, always create a tuple argument for the
//   entry computation.
// prefer_tf2xla: when this is true, prefer tf2xla fallback kernels over MLIR
//   native kernels for legalization to HLO.
// return_tuple: when this is true, always create a tuple result for the
//   entry computation.
// shape_representation_fn: when this is set, this shape representation function
//   will be used to determine argument and result shapes. Otherwise the
//   original shape will be used as is.
// custom_legalization_passes: passes to run before the default TF legalization
//   passes for backend-specific ops.
//
// TODO(hinsu): Migrate options to a separate struct.
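//
// A minimal usage sketch (`module_op` is assumed to already hold a tf dialect
// module; the optional arguments keep their defaults):
//
//   xla::XlaComputation computation;
//   TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
//       module_op, /*device_type=*/"XLA_TPU_JIT", &computation,
//       /*use_tuple_args=*/true, /*prefer_tf2xla=*/false,
//       /*return_tuple=*/true));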
Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, llvm::StringRef device_type,
    xla::XlaComputation* xla_computation, bool use_tuple_args,
    bool prefer_tf2xla, bool return_tuple,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes = {});

// Helper struct representing argument tensor or resource handle shapes.
struct TensorOrResourceShape {
  TensorShape shape;
  bool is_resource = false;
};

// Refine MLIR types based on new shape information.
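//
// For example (a sketch; the shape is illustrative and `module_op` is assumed
// to exist):
//
//   std::vector<TensorOrResourceShape> arg_shapes;
//   arg_shapes.push_back({TensorShape({2, 3}), /*is_resource=*/false});
//   TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));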
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                    mlir::ModuleOp module);

// Lowers TF to MHLO and inserts HLO into the XlaBuilder. xla_params are
// HLO-level inputs to module_op that have already been added to the
// XlaBuilder, and the resulting XlaOps are appended to returns.
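//
// A minimal usage sketch (the f32 parameter shape is illustrative; `module_op`
// and `arg_shapes` are assumed to come from the surrounding code):
//
//   xla::XlaBuilder builder("main");
//   std::vector<xla::XlaOp> xla_params = {xla::Parameter(
//       &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "arg0")};
//   std::vector<xla::XlaOp> returns;
//   TF_RETURN_IF_ERROR(BuildHloFromTf(
//       module_op, builder, xla_params, returns, arg_shapes,
//       /*device_type=*/"XLA_CPU_JIT", /*custom_legalization_passes=*/{}));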
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                      llvm::ArrayRef<xla::XlaOp> xla_params,
                      std::vector<xla::XlaOp>& returns,
                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                      llvm::StringRef device_type,
                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                          custom_legalization_passes);

// Applies shape, description, and resource information to inputs and outputs
// in the XlaCompilationResult. This should be called after
// compilation_result->computation has been set.
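//
// For example, after lowering with BuildHloFromTf (a sketch; `module_op` and
// `arg_shapes` are assumed to exist):
//
//   XlaCompilationResult compilation_result;
//   // ... set compilation_result.computation first ...
//   TF_RETURN_IF_ERROR(PopulateResultIOInfo(
//       module_op, arg_shapes, /*use_tuple_args=*/true,
//       /*use_resource_updates_for_aliases=*/false,
//       /*shape_representation_fn=*/nullptr, &compilation_result));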
Status PopulateResultIOInfo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args, bool use_resource_updates_for_aliases,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result);

// Compiles an MLIR module into XLA HLO, generates all accompanying metadata
// and stores them in CompilationResult.
//
// If analyse_graph is true, the module is legalized only if graph analysis
// succeeds; otherwise an error is returned.
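//
// A minimal usage sketch (no argument has a default, so each is spelled out;
// `module_op` and `arg_shapes` are assumed to exist):
//
//   XlaCompilationResult compilation_result;
//   TF_RETURN_IF_ERROR(CompileMlirToXlaHlo(
//       module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT",
//       /*use_tuple_args=*/true, /*analyse_graph=*/false,
//       /*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
//       /*shape_representation_fn=*/nullptr, &compilation_result,
//       /*custom_legalization_passes=*/{}));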
Status CompileMlirToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    bool use_return_tuple, bool use_resource_updates_for_aliases,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes);

// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult.
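//
// A minimal usage sketch (`mlir_module_string` is assumed to hold the textual
// tf dialect module; the shape is illustrative):
//
//   std::vector<TensorShape> arg_shapes = {TensorShape({2, 3})};
//   XlaCompilationResult compilation_result;
//   TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
//       mlir_module_string, arg_shapes, /*device_type=*/"XLA_CPU_JIT",
//       /*use_tuple_args=*/true, /*analyse_graph=*/false,
//       /*shape_representation_fn=*/nullptr, &compilation_result));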
Status CompileSerializedMlirToXlaHlo(
    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes = {});

// Compiles a TensorFlow Graph (already converted to MLIR, imported with
// tf_executor dialect still present) into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult. This will rewrite arguments
// and run the TensorFlow standard pipeline prior to invoking
// `CompileMlirToXlaHlo`.
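//
// A minimal usage sketch (`module_op` and `args` are assumed to come from the
// graph import step):
//
//   XlaCompilationResult compilation_result;
//   TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
//       module_op, args, /*device_type=*/"XLA_CPU_JIT",
//       /*use_tuple_args=*/true, /*analyse_graph=*/false,
//       /*use_return_tuple=*/true, /*shape_representation_fn=*/nullptr,
//       &compilation_result, /*custom_legalization_passes=*/{}));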
Status CompileGraphToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    bool use_return_tuple,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes);

// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult.
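//
// A minimal usage sketch (`graph`, `args`, and `flib_def` are assumed to come
// from the surrounding compilation context):
//
//   GraphDebugInfo debug_info;
//   XlaCompilationResult compilation_result;
//   TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
//       graph, args, /*control_rets=*/{}, /*device_type=*/"XLA_CPU_JIT",
//       /*use_tuple_args=*/true, /*analyse_graph=*/false, flib_def, debug_info,
//       /*shape_representation_fn=*/nullptr, &compilation_result));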
Status CompileGraphToXlaHlo(
    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
    bool use_tuple_args, bool analyse_graph,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes = {});

// Compiles a Graph from TF to HLO and adds the resulting HLO to the
// XlaBuilder. This function adds HLO to a larger HLO computation, so
// HLO-level inputs are supplied, and HLO-level outputs are produced.
// xla_params are the HLO-level inputs and returns are the HLO-level outputs.
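//
// A minimal usage sketch (`graph`, `args`, and `flib_def` are assumed to
// exist; the single f32 parameter is illustrative):
//
//   xla::XlaBuilder builder("graph");
//   std::vector<xla::XlaOp> xla_params = {xla::Parameter(
//       &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2}), "arg0")};
//   std::vector<xla::XlaOp> returns;
//   GraphDebugInfo debug_info;
//   TF_RETURN_IF_ERROR(BuildHloFromGraph(
//       graph, builder, xla_params, returns, args, /*control_rets=*/{},
//       /*device_type=*/"XLA_CPU_JIT", flib_def, debug_info));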
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
                         llvm::ArrayRef<xla::XlaOp> xla_params,
                         std::vector<xla::XlaOp>& returns,
                         llvm::ArrayRef<XlaArgument> args,
                         llvm::ArrayRef<std::string> control_rets,
                         llvm::StringRef device_type,
                         const FunctionLibraryDefinition& flib_def,
                         const GraphDebugInfo& debug_info,
                         llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                             custom_legalization_passes = {});

static inline Status CompileToHloGraphAnalysisFailedError() {
  return errors::Internal("disabled after graph analysis");
}
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_