• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
17 
18 #include <memory>
19 
20 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
21 #include "mlir/Pass/PassManager.h"  // from @llvm-project
22 #include "mlir/Transforms/Passes.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
24 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
25 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
27 
28 namespace mlir {
29 namespace {
30 // Add logger to bridge passmanager.
EnableLogging(PassManager * pm)31 void EnableLogging(PassManager *pm) {
32   // Print the whole module after each pass, which requires disabling
33   // multi-threading as well.
34   pm->getContext()->disableMultithreading();
35   pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
36       /*print_module_scope=*/true));
37   pm->enableTiming(std::make_unique<tensorflow::BridgeTimingConfig>());
38 }
39 }  // namespace
40 
41 namespace TFTPU {
42 namespace {
AddGraphExportLoweringPasses(OpPassManager & pm)43 void AddGraphExportLoweringPasses(OpPassManager &pm) {
44   auto add_pass = [&](std::unique_ptr<Pass> pass) {
45     pm.addNestedPass<FuncOp>(std::move(pass));
46     pm.addPass(CreateBreakUpIslandsPass());
47   };
48 
49   add_pass(CreateFunctionalToExecutorDialectConversionPass());
50   add_pass(TFDevice::CreateReplicateToIslandPass());
51   add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
52   add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
53   pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass());
54   pm.addPass(createSymbolDCEPass());
55 }
56 
RunTPUBridge(ModuleOp module,bool enable_logging,llvm::function_ref<void (OpPassManager & pm)> pipeline_builder)57 tensorflow::Status RunTPUBridge(
58     ModuleOp module, bool enable_logging,
59     llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
60   PassManager bridge(module.getContext());
61   ::tensorflow::applyTensorflowAndCLOptions(bridge);
62   if (enable_logging || VLOG_IS_ON(1)) {
63     tensorflow::DumpMlirOpToFile("tpu_bridge_before", module);
64     if (VLOG_IS_ON(2)) EnableLogging(&bridge);
65   }
66 
67   // Populate a passmanager with the list of passes that implement the bridge.
68   pipeline_builder(bridge);
69 
70   // Add set of passes to lower back to graph (from tf_executor).
71   AddGraphExportLoweringPasses(bridge);
72 
73   // Run the bridge on the module, in case of failure, the `diag_handler`
74   // converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
75   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
76   LogicalResult result = bridge.run(module);
77   (void)result;
78   if (enable_logging || VLOG_IS_ON(1))
79     tensorflow::DumpMlirOpToFile("tpu_bridge_after", module);
80   return diag_handler.ConsumeStatus();
81 }
82 }  // namespace
83 
// Populates `pm` with the full TF2 TPU bridge pipeline: executor graph
// pruning, functionalization, shape inference, TPU cluster formation,
// resource-op lifting, outside-compilation extraction, cluster outlining,
// and the final TPU rewrite. NOTE(review): the pass ordering below is
// load-bearing (several comments explain specific orderings) — do not
// reorder without understanding the inter-pass dependencies.
void CreateTPUBridgePipeline(OpPassManager &pm) {
  // The following ops must be preserved regardless of reachability. Ideally,
  // all graphs should have control dependencies to enforce this but this is
  // currently not the case (see b/177478741).
  const llvm::SmallVector<std::string, 4> ops_to_preserve = {
      "tf.TPUReplicateMetadata", "tf.TPUCompilationResult",
      "tf.TPUReplicatedInput", "tf.TPUReplicatedOutput"};
  pm.addNestedPass<FuncOp>(
      tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
  // It is assumed at this stage there are no V1 control flow ops as Graph
  // functionalization is ran before import. Ops can be lifted out of
  // tf_executor dialect islands/graphs.
  pm.addNestedPass<FuncOp>(CreateExecutorDialectToFunctionalConversionPass());
  // Run shape inference so that tf_executor/tf_device ops created later will
  // likely to inherit more concrete types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<FuncOp>(CreateTPUReorderReplicateAndPartitionedInputsPass());
  // Encode this in its own scope so that func_pm is not mistakenly used
  // later on.
  {
    pm.addPass(CreateTPUClusterFormationPass());
    OpPassManager &func_pm = pm.nest<FuncOp>();
    // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
    // because DecomposeResourceOpsPass uses pattern rewriter which hoists
    // changed constants out of tf_device.Launch.
    func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
    func_pm.addPass(CreateTPUHostComputationExpansionPass());
    func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
  }
  // Run another shape inference pass because resource decomposition might have
  // created new partial types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  pm.addPass(TFDevice::CreateResourceOpLiftingPass());
  pm.addNestedPass<FuncOp>(createCSEPass());
  // Mark ops for outside compilation, then extract head/tail and remaining
  // outside-compiled computations from the clusters.
  pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
  pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
  pm.addPass(CreateTPUExtractOutsideCompilationPass());

  pm.addNestedPass<FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
  pm.addPass(TF::CreateResourceDeviceInferencePass());
  pm.addPass(TFDevice::CreateClusterOutliningPass());
  pm.addPass(CreateTPUDynamicPaddingMapperPass());
  pm.addPass(CreateTPUResourceReadForWritePass());
  pm.addPass(CreateTPUShardingIdentificationPass());
  pm.addNestedPass<FuncOp>(CreateTPUResourceReadsWritesPartitioningPass());
  pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
  pm.addPass(CreateTPURewritePass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
  pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
  pm.addPass(CreateTPUVariableReformattingPass());
  // Convert region-based control flow back to functional form before the
  // graph-export lowering passes run (see AddGraphExportLoweringPasses).
  pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
141 
// Populates `pm` with the V1-compatibility variant of the TPU bridge
// pipeline: the TPU computation is first coarsened and outlined into a
// nested module, the regular TF2 pipeline runs on that nested module, and
// the result is inlined back into the parent.
void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  pm.addPass(TF::CreateTFShapeInferencePass());
  // For V1 compatibility, we process a module where the graph does not have
  // feeds and fetched. We extract first the TPU computation in a submodule,
  // where it'll be in a function with args and returned values, much more like
  // a TF v2 module. We can then run the usual pipeline on this nested module.
  // Afterward we inline back in the parent module and delete the nested one.
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
  OpPassManager &nested_module = pm.nest<ModuleOp>();
  CreateTPUBridgePipeline(nested_module);
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
}
155 
TPUBridge(ModuleOp module,bool enable_logging)156 tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
157   return RunTPUBridge(module, enable_logging, CreateTPUBridgePipeline);
158 }
TPUBridgeV1Compat(ModuleOp module,bool enable_logging)159 tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) {
160   return RunTPUBridge(module, enable_logging, CreateTPUBridgePipelineV1);
161 }
162 
163 }  // namespace TFTPU
164 
165 namespace TF {
166 
RunBridgeWithStandardPipeline(ModuleOp module,bool enable_logging,bool enable_inliner)167 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
168                                                  bool enable_logging,
169                                                  bool enable_inliner) {
170   PassManager bridge(module.getContext());
171   if (enable_logging || VLOG_IS_ON(1)) {
172     tensorflow::DumpMlirOpToFile("standard_pipeline_before", module);
173     if (VLOG_IS_ON(2)) EnableLogging(&bridge);
174   }
175 
176   StandardPipelineOptions pipeline_options;
177   pipeline_options.enable_inliner.setValue(enable_inliner);
178   CreateTFStandardPipeline(bridge, pipeline_options);
179   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
180   LogicalResult result = bridge.run(module);
181   (void)result;
182   if (enable_logging || VLOG_IS_ON(1))
183     tensorflow::DumpMlirOpToFile("standard_pipeline_after", module);
184   return diag_handler.ConsumeStatus();
185 }
186 
187 }  // namespace TF
188 }  // namespace mlir
189