1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
17
18 #include <memory>
19
20 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
21 #include "mlir/Pass/PassManager.h" // from @llvm-project
22 #include "mlir/Transforms/Passes.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
24 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
25 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
27
28 namespace mlir {
29 namespace {
30 // Add logger to bridge passmanager.
EnableLogging(PassManager * pm)31 void EnableLogging(PassManager *pm) {
32 // Print the whole module after each pass, which requires disabling
33 // multi-threading as well.
34 pm->getContext()->disableMultithreading();
35 pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
36 /*print_module_scope=*/true));
37 pm->enableTiming(std::make_unique<tensorflow::BridgeTimingConfig>());
38 }
39 } // namespace
40
41 namespace TFTPU {
42 namespace {
AddGraphExportLoweringPasses(OpPassManager & pm)43 void AddGraphExportLoweringPasses(OpPassManager &pm) {
44 auto add_pass = [&](std::unique_ptr<Pass> pass) {
45 pm.addNestedPass<FuncOp>(std::move(pass));
46 pm.addPass(CreateBreakUpIslandsPass());
47 };
48
49 add_pass(CreateFunctionalToExecutorDialectConversionPass());
50 add_pass(TFDevice::CreateReplicateToIslandPass());
51 add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
52 add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
53 pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass());
54 pm.addPass(createSymbolDCEPass());
55 }
56
RunTPUBridge(ModuleOp module,bool enable_logging,llvm::function_ref<void (OpPassManager & pm)> pipeline_builder)57 tensorflow::Status RunTPUBridge(
58 ModuleOp module, bool enable_logging,
59 llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
60 PassManager bridge(module.getContext());
61 ::tensorflow::applyTensorflowAndCLOptions(bridge);
62 if (enable_logging || VLOG_IS_ON(1)) {
63 tensorflow::DumpMlirOpToFile("tpu_bridge_before", module);
64 if (VLOG_IS_ON(2)) EnableLogging(&bridge);
65 }
66
67 // Populate a passmanager with the list of passes that implement the bridge.
68 pipeline_builder(bridge);
69
70 // Add set of passes to lower back to graph (from tf_executor).
71 AddGraphExportLoweringPasses(bridge);
72
73 // Run the bridge on the module, in case of failure, the `diag_handler`
74 // converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
75 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
76 LogicalResult result = bridge.run(module);
77 (void)result;
78 if (enable_logging || VLOG_IS_ON(1))
79 tensorflow::DumpMlirOpToFile("tpu_bridge_after", module);
80 return diag_handler.ConsumeStatus();
81 }
82 } // namespace
83
// Populates `pm` with the full set of passes making up the TPU bridge
// pipeline (V2 path). Pass ordering here is significant — several comments
// below record explicit ordering constraints.
void CreateTPUBridgePipeline(OpPassManager &pm) {
  // The following ops must be preserved regardless of reachability. Ideally,
  // all graphs should have control dependencies to enforce this but this is
  // currently not the case (see b/177478741).
  const llvm::SmallVector<std::string, 4> ops_to_preserve = {
      "tf.TPUReplicateMetadata", "tf.TPUCompilationResult",
      "tf.TPUReplicatedInput", "tf.TPUReplicatedOutput"};
  pm.addNestedPass<FuncOp>(
      tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
  // It is assumed at this stage there are no V1 control flow ops as Graph
  // functionalization is ran before import. Ops can be lifted out of
  // tf_executor dialect islands/graphs.
  pm.addNestedPass<FuncOp>(CreateExecutorDialectToFunctionalConversionPass());
  // Run shape inference so that tf_executor/tf_device ops created later will
  // likely to inherit more concrete types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<FuncOp>(CreateTPUReorderReplicateAndPartitionedInputsPass());
  // Encode this in its own scope so that func_pm is not mistakenly used
  // later on.
  {
    pm.addPass(CreateTPUClusterFormationPass());
    OpPassManager &func_pm = pm.nest<FuncOp>();
    // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
    // because DecomposeResourceOpsPass uses pattern rewriter which hoists
    // changed constants out of tf_device.Launch.
    func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
    func_pm.addPass(CreateTPUHostComputationExpansionPass());
    func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
  }
  // Run another shape inference pass because resource decomposition might have
  // created new partial types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  pm.addPass(TFDevice::CreateResourceOpLiftingPass());
  pm.addNestedPass<FuncOp>(createCSEPass());
  // Outside-compilation handling: marking runs before the two extraction
  // passes (head/tail extraction precedes general extraction).
  pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
  pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
  pm.addPass(CreateTPUExtractOutsideCompilationPass());

  pm.addNestedPass<FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
  pm.addPass(TF::CreateResourceDeviceInferencePass());
  pm.addPass(TFDevice::CreateClusterOutliningPass());
  pm.addPass(CreateTPUDynamicPaddingMapperPass());
  pm.addPass(CreateTPUResourceReadForWritePass());
  pm.addPass(CreateTPUShardingIdentificationPass());
  pm.addNestedPass<FuncOp>(CreateTPUResourceReadsWritesPartitioningPass());
  pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
  // TPURewrite runs after the cluster/resource preparation above; symbol DCE
  // then cleans up functions it leaves unreferenced.
  pm.addPass(CreateTPURewritePass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
  pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
  pm.addPass(CreateTPUVariableReformattingPass());
  // Finish by converting region-based control flow back to functional form,
  // mirroring the FunctionalControlFlowToRegions conversion done earlier.
  pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
141
CreateTPUBridgePipelineV1(OpPassManager & pm)142 void CreateTPUBridgePipelineV1(OpPassManager &pm) {
143 pm.addPass(TF::CreateTFShapeInferencePass());
144 // For V1 compatibility, we process a module where the graph does not have
145 // feeds and fetched. We extract first the TPU computation in a submodule,
146 // where it'll be in a function with args and returned values, much more like
147 // a TF v2 module. We can then run the usual pipeline on this nested module.
148 // Afterward we inline back in the parent module and delete the nested one.
149 pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
150 pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
151 OpPassManager &nested_module = pm.nest<ModuleOp>();
152 CreateTPUBridgePipeline(nested_module);
153 pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
154 }
155
TPUBridge(ModuleOp module,bool enable_logging)156 tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
157 return RunTPUBridge(module, enable_logging, CreateTPUBridgePipeline);
158 }
TPUBridgeV1Compat(ModuleOp module,bool enable_logging)159 tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) {
160 return RunTPUBridge(module, enable_logging, CreateTPUBridgePipelineV1);
161 }
162
163 } // namespace TFTPU
164
165 namespace TF {
166
RunBridgeWithStandardPipeline(ModuleOp module,bool enable_logging,bool enable_inliner)167 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
168 bool enable_logging,
169 bool enable_inliner) {
170 PassManager bridge(module.getContext());
171 if (enable_logging || VLOG_IS_ON(1)) {
172 tensorflow::DumpMlirOpToFile("standard_pipeline_before", module);
173 if (VLOG_IS_ON(2)) EnableLogging(&bridge);
174 }
175
176 StandardPipelineOptions pipeline_options;
177 pipeline_options.enable_inliner.setValue(enable_inliner);
178 CreateTFStandardPipeline(bridge, pipeline_options);
179 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
180 LogicalResult result = bridge.run(module);
181 (void)result;
182 if (enable_logging || VLOG_IS_ON(1))
183 tensorflow::DumpMlirOpToFile("standard_pipeline_after", module);
184 return diag_handler.ConsumeStatus();
185 }
186
187 } // namespace TF
188 } // namespace mlir
189