/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"

#include <memory>

#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"

namespace mlir {
namespace {
// Add logger to bridge passmanager.
// Enable timing statistics per pass for the bridge passmanager.
void EnableDetailedLogging(PassManager *pm) {
  // Print the whole module after each pass, which requires disabling
  // multi-threading as well.
  pm->getContext()->disableMultithreading();
  pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
      /*print_module_scope=*/true));
  pm->enableTiming();
}
}  // namespace

namespace TFTPU {

namespace {
tensorflow::Status RunTPUBridge(
    ModuleOp module, bool enable_logging,
    llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
  PassManager bridge(module.getContext());
  ::tensorflow::applyTensorflowAndCLOptions(bridge);
  if (enable_logging || VLOG_IS_ON(1)) {
    tensorflow::DumpMlirOpToFile("tpu_bridge_before", module);
    if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
  }

  // Populate a passmanager with the list of passes that implement the bridge.
  pipeline_builder(bridge);

  // Add a set of passes to lower back to graph (from tf_executor).
  TF::AddGraphExportLoweringPasses(bridge);

  // Run the bridge on the module. In case of failure, the `diag_handler`
  // converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  LogicalResult result = bridge.run(module);
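  // The LogicalResult is intentionally ignored; any errors emitted during the
  // run are captured by `diag_handler` and returned as a Status below.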
  (void)result;
  if (enable_logging || VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("tpu_bridge_after", module);
  return diag_handler.ConsumeStatus();
}
}  // namespace

void CreateTPUBridgePipeline(OpPassManager &pm) {
  // The following ops must be preserved regardless of reachability. Ideally,
  // all graphs should have control dependencies to enforce this but this is
  // currently not the case (see b/177478741).
  const llvm::SmallVector<std::string, 4> ops_to_preserve = {
      "tf.TPUReplicateMetadata", "tf.TPUCompilationResult",
      "tf.TPUReplicatedOutput"};
  pm.addNestedPass<FuncOp>(
      tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
  // It is assumed at this stage there are no V1 control flow ops as Graph
  // functionalization is run before import. Ops can be lifted out of
  // tf_executor dialect islands/graphs.
  pm.addNestedPass<FuncOp>(CreateExecutorDialectToFunctionalConversionPass());
  // Guarantee all functions have one use, which enables more exact shape
  // inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run shape inference so that tf_executor/tf_device ops created later will
  // likely inherit more concrete types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<FuncOp>(CreateTPUReorderReplicateAndPartitionedInputsPass());
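  // Group ops that belong to the same TPU computation into a
  // tf_device.cluster op.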
  pm.addPass(CreateTPUClusterFormationPass());
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateDeviceAttributeToLaunchPass());
  // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
  // because DecomposeResourceOpsPass uses pattern rewriter which hoists
  // changed constants out of tf_device.Launch.
  pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
  // Encode this in its own scope so that func_pm is not mistakenly used
  // later on.
  {
    OpPassManager &func_pm = pm.nest<FuncOp>();
    func_pm.addPass(CreateTPUHostComputationExpansionPass());
    func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
  }
  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());

  // TODO(b/173622615): Once OutsideCompilation is represented by launch op and
  // the remaining passes including Inliner support it, remove this
  // LaunchToDeviceAttributePass. This LaunchToDeviceAttribute pass needs to
  // come before TPUClusterCleanupAttributes pass or else the device attribute
  // will be removed from launch causing an error.
  pm.addNestedPass<FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());

  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<FuncOp>(
      TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
  // Run another shape inference pass because resource decomposition might have
  // created new partial types. Also, dropping the `shape_invariant` attribute
  // from While/WhileRegion ops within the cluster allows more precise shapes
  // to be inferred.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<FuncOp>(createCanonicalizerPass());
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());
  pm.addPass(TFDevice::CreateResourceOpLiftingPass());
  // Re-run the canonicalizer pass as some cleanup during the resource op
  // lifting pass opens up opportunities for canonicalization of cluster ops.
  // Specifically, we want to eliminate pass-through results from the cluster
  // op.
  pm.addNestedPass<FuncOp>(createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(createCSEPass());
  if (tensorflow::GetMlirCommonFlags()
          ->tf_mlir_enable_merge_control_flow_pass) {
    pm.addPass(TFDevice::CreateMergeControlFlowPass());
  }

  pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
  pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
  pm.addPass(CreateTPUExtractOutsideCompilationPass());

  pm.addNestedPass<FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
  pm.addPass(TF::CreateResourceDeviceInferencePass());
  pm.addPass(TFDevice::CreateClusterOutliningPass());
  pm.addPass(CreateTPUResourceReadForWritePass());
  pm.addPass(CreateTPUShardingIdentificationPass());
  pm.addNestedPass<FuncOp>(CreateTPUResourceReadsWritesPartitioningPass());
  pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
  pm.addPass(TFDevice::CreateMarkInputOutputAliasesPass());
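  // Rewrite each tf_device.cluster_func into TPU compile and execute runtime
  // ops.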
  pm.addPass(CreateTPURewritePass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addPass(CreateTPUMergeVariablesWithExecutePass());
  pm.addNestedPass<FuncOp>(
      TF::CreateHoistReplicateInvariantResourceWritesPass());
  pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
  pm.addPass(CreateTPUVariableReformattingPass());
  pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}

void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  // Guarantee all functions have one use, which enables more exact shape
  // inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  pm.addPass(TF::CreateTFShapeInferencePass());
  // For V1 compatibility, we process a module where the graph does not have
  // feeds and fetches. We first extract the TPU computation into a submodule,
  // where it will be in a function with args and returned values, much more
  // like a TF v2 module. We can then run the usual pipeline on this nested
  // module. Afterward we inline it back into the parent module and delete the
  // nested one.
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
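  // Run the regular (V2) TPU bridge pipeline on the outlined nested module.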
  OpPassManager &nested_module = pm.nest<ModuleOp>();
  CreateTPUBridgePipeline(nested_module);
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
}

tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
  return RunTPUBridge(module, enable_logging, CreateTPUBridgePipeline);
}
tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) {
  return RunTPUBridge(module, enable_logging, CreateTPUBridgePipelineV1);
}

}  // namespace TFTPU

namespace TF {

void AddGraphExportLoweringPasses(OpPassManager &pm) {
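  // Helper that adds the given pass (nested on functions), followed by a pass
  // that breaks up tf_executor islands so each island wraps a single op.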
  auto add_pass = [&](std::unique_ptr<Pass> pass) {
    pm.addNestedPass<FuncOp>(std::move(pass));
    pm.addPass(CreateBreakUpIslandsPass());
  };

  add_pass(CreateFunctionalToExecutorDialectConversionPass());
  add_pass(TFDevice::CreateReplicateToIslandPass());
  add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
  add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
  pm.addNestedPass<FuncOp>(TFTPU::CreateTPUDevicePropagationPass());
  pm.addPass(createSymbolDCEPass());
  pm.addPass(CreateVerifySuitableForExportPass());
}

tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
                                                 bool enable_logging,
                                                 bool enable_inliner) {
  PassManager bridge(module.getContext());
  if (enable_logging || VLOG_IS_ON(1)) {
    tensorflow::DumpMlirOpToFile("standard_pipeline_before", module);
    if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
  }

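  // Build the standard TF optimization pipeline; the inliner is enabled only
  // when requested by the caller.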
  StandardPipelineOptions pipeline_options;
  pipeline_options.enable_inliner.setValue(enable_inliner);
  CreateTFStandardPipeline(bridge, pipeline_options);
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  LogicalResult result = bridge.run(module);
  (void)result;
  if (enable_logging || VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("standard_pipeline_after", module);
  return diag_handler.ConsumeStatus();
}

}  // namespace TF
}  // namespace mlir