1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
17
18 #include <memory>
19
20 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
21 #include "mlir/Pass/PassManager.h" // from @llvm-project
22 #include "mlir/Transforms/Passes.h" // from @llvm-project
23 #include "tensorflow/compiler/jit/flags.h"
24 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
25 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
27 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
28
29 namespace mlir {
30 namespace {
31 // Add logger to bridge passmanager.
32 // Enable timing statistics per pass for the bridge passmanager.
EnableDetailedLogging(PassManager * pm)33 void EnableDetailedLogging(PassManager *pm) {
34 // Print the whole module after each pass, which requires disabling
35 // multi-threading as well.
36 pm->getContext()->disableMultithreading();
37 pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
38 /*print_module_scope=*/true));
39 pm->enableTiming();
40 }
41 } // namespace
42
43 namespace TFTPU {
44
45 namespace {
RunTPUBridge(ModuleOp module,bool enable_logging,llvm::function_ref<void (OpPassManager & pm)> pipeline_builder)46 tensorflow::Status RunTPUBridge(
47 ModuleOp module, bool enable_logging,
48 llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
49 PassManager bridge(module.getContext());
50 ::tensorflow::applyTensorflowAndCLOptions(bridge);
51 if (enable_logging || VLOG_IS_ON(1)) {
52 tensorflow::DumpMlirOpToFile("tpu_bridge_before", module);
53 if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
54 }
55
56 // Populate a passmanager with the list of passes that implement the bridge.
57 pipeline_builder(bridge);
58
59 // Add set of passes to lower back to graph (from tf_executor).
60 TF::AddGraphExportLoweringPasses(bridge);
61
62 // Run the bridge on the module, in case of failure, the `diag_handler`
63 // converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
64 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
65 LogicalResult result = bridge.run(module);
66 (void)result;
67 if (enable_logging || VLOG_IS_ON(1))
68 tensorflow::DumpMlirOpToFile("tpu_bridge_after", module);
69 return diag_handler.ConsumeStatus();
70 }
71 } // namespace
72
// Populates `pm` with the TPU bridge pipeline: executor-graph pruning and
// functionalization, shape inference, TPU cluster formation, outside
// compilation handling, resource lifting, and finally rewriting clusters
// into TPU runtime ops. Pass order here is load-bearing — several passes
// toggle the outside-compilation representation back and forth around
// passes that only support one form (see the b/173622615 TODOs).
void CreateTPUBridgePipeline(OpPassManager &pm) {
  // The following ops must be preserved regardless of reachability. Ideally,
  // all graphs should have control dependencies to enforce this but this is
  // currently not the case (see b/177478741).
  const llvm::SmallVector<std::string, 4> ops_to_preserve = {
      "tf.TPUReplicateMetadata", "tf.TPUCompilationResult",
      "tf.TPUReplicatedOutput"};
  pm.addNestedPass<FuncOp>(
      tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
  // It is assumed at this stage there are no V1 control flow ops as Graph
  // functionalization is ran before import. Ops can be lifted out of
  // tf_executor dialect islands/graphs.
  pm.addNestedPass<FuncOp>(CreateExecutorDialectToFunctionalConversionPass());
  // Guarantee all functions have one use, which enables more exact shape
  // inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run shape inference so that tf_executor/tf_device ops created later will
  // likely to inherit more concrete types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<FuncOp>(CreateTPUReorderReplicateAndPartitionedInputsPass());
  pm.addPass(CreateTPUClusterFormationPass());
  // Represent outside-compiled ops as tf_device.launch on the host device.
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateDeviceAttributeToLaunchPass());
  // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
  // because DecomposeResourceOpsPass uses pattern rewriter which hoists
  // changed constants out of tf_device.Launch.
  pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
  // Encode this in its own scope so that func_pm is not mistakenly used
  // later on.
  {
    OpPassManager &func_pm = pm.nest<FuncOp>();
    func_pm.addPass(CreateTPUHostComputationExpansionPass());
    func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
  }
  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());

  // TODO(b/173622615): Once OutsideCompilation is represented by launch op and
  // the remaining passes including Inliner support it, remove this
  // LaunchToDeviceAttributePass. This LaunchToDeviceAttribute pass needs to
  // come before TPUClusterCleanupAttributes pass or else the device attribute
  // will be removed from launch causing an error.
  pm.addNestedPass<FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());

  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
  // Re-apply the launch representation for the inliner; the earlier
  // HostLaunchToOutsideCompiled pass undid it (see b/173622615 TODO above).
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<FuncOp>(
      TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
  // Run another shape inference pass because resource decomposition might have
  // created new partial types. Also, after dropping `shape_invariant` attribute
  // from While/WhileRegion ops within cluster would lead to more precise
  // shapes.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<FuncOp>(createCanonicalizerPass());
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());
  pm.addPass(TFDevice::CreateResourceOpLiftingPass());
  // Re-run the canonicalizer pass as some cleanup during resource op lifting
  // pass opens up some opportunities for canonicalization of cluster ops.
  // Specifically, we want to eliminate pass through results from the cluster
  // op.
  pm.addNestedPass<FuncOp>(createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(createCSEPass());
  if (tensorflow::GetMlirCommonFlags()
          ->tf_mlir_enable_merge_control_flow_pass) {
    pm.addPass(TFDevice::CreateMergeControlFlowPass());
  }

  // Outside compilation: mark candidate ops, then extract head/tail and
  // remaining outside-compiled computations out of the cluster.
  pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
  pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
  pm.addPass(CreateTPUExtractOutsideCompilationPass());

  pm.addNestedPass<FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
  pm.addPass(TF::CreateResourceDeviceInferencePass());
  pm.addPass(TFDevice::CreateClusterOutliningPass());
  pm.addPass(CreateTPUResourceReadForWritePass());
  pm.addPass(CreateTPUShardingIdentificationPass());
  pm.addNestedPass<FuncOp>(CreateTPUResourceReadsWritesPartitioningPass());
  pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
  pm.addPass(TFDevice::CreateMarkInputOutputAliasesPass());
  // Rewrite outlined clusters into TPU runtime ops (compile/execute).
  pm.addPass(CreateTPURewritePass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addPass(CreateTPUMergeVariablesWithExecutePass());
  pm.addNestedPass<FuncOp>(
      TF::CreateHoistReplicateInvariantResourceWritesPass());
  pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
  pm.addPass(CreateTPUVariableReformattingPass());
  // Convert back to functional control flow before export lowering.
  pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
171
// Populates `pm` with the TPU bridge pipeline for TF V1 compatibility
// modules: the TPU computation is first outlined into a nested module so the
// regular (V2-style) pipeline can run on it, then inlined back.
void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  // Guarantee all functions have one use, which enables more exact shape
  // inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  pm.addPass(TF::CreateTFShapeInferencePass());
  // For V1 compatibility, we process a module where the graph does not have
  // feeds and fetched. We extract first the TPU computation in a submodule,
  // where it'll be in a function with args and returned values, much more like
  // a TF v2 module. We can then run the usual pipeline on this nested module.
  // Afterward we inline back in the parent module and delete the nested one.
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
  // Run the full V2 bridge pipeline on the outlined submodule only.
  OpPassManager &nested_module = pm.nest<ModuleOp>();
  CreateTPUBridgePipeline(nested_module);
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
}
188
TPUBridge(ModuleOp module,bool enable_logging)189 tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
190 return RunTPUBridge(module, enable_logging, CreateTPUBridgePipeline);
191 }
TPUBridgeV1Compat(ModuleOp module,bool enable_logging)192 tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) {
193 return RunTPUBridge(module, enable_logging, CreateTPUBridgePipelineV1);
194 }
195
196 } // namespace TFTPU
197
198 namespace TF {
199
AddGraphExportLoweringPasses(OpPassManager & pm)200 void AddGraphExportLoweringPasses(OpPassManager &pm) {
201 auto add_pass = [&](std::unique_ptr<Pass> pass) {
202 pm.addNestedPass<FuncOp>(std::move(pass));
203 pm.addPass(CreateBreakUpIslandsPass());
204 };
205
206 add_pass(CreateFunctionalToExecutorDialectConversionPass());
207 add_pass(TFDevice::CreateReplicateToIslandPass());
208 add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
209 add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
210 pm.addNestedPass<FuncOp>(TFTPU::CreateTPUDevicePropagationPass());
211 pm.addPass(createSymbolDCEPass());
212 pm.addPass(CreateVerifySuitableForExportPass());
213 }
214
RunBridgeWithStandardPipeline(ModuleOp module,bool enable_logging,bool enable_inliner)215 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
216 bool enable_logging,
217 bool enable_inliner) {
218 PassManager bridge(module.getContext());
219 if (enable_logging || VLOG_IS_ON(1)) {
220 tensorflow::DumpMlirOpToFile("standard_pipeline_before", module);
221 if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
222 }
223
224 StandardPipelineOptions pipeline_options;
225 pipeline_options.enable_inliner.setValue(enable_inliner);
226 CreateTFStandardPipeline(bridge, pipeline_options);
227 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
228 LogicalResult result = bridge.run(module);
229 (void)result;
230 if (enable_logging || VLOG_IS_ON(1))
231 tensorflow::DumpMlirOpToFile("standard_pipeline_after", module);
232 return diag_handler.ConsumeStatus();
233 }
234
235 } // namespace TF
236 } // namespace mlir
237