• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
17 
18 #include <memory>
19 
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/Pass/PassManager.h"  // from @llvm-project
23 #include "mlir/Transforms/Passes.h"  // from @llvm-project
24 #include "tensorflow/compiler/jit/flags.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
27 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
29 #include "tensorflow/core/framework/metrics.h"
30 #include "tensorflow/core/platform/error_payloads.h"
31 #include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
32 
33 namespace mlir {
34 namespace {
35 
36 // Add logger to bridge passmanager.
37 // Enable timing statistics per pass for the bridge passmanager.
EnableDetailedLogging(PassManager * pm)38 void EnableDetailedLogging(PassManager *pm) {
39   // Print the whole module after each pass, which requires disabling
40   // multi-threading as well.
41   pm->getContext()->disableMultithreading();
42   pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
43       /*print_module_scope=*/true));
44   pm->enableTiming();
45 }
46 }  // namespace
47 
48 namespace TFTPU {
49 
50 namespace {
51 // Run the TF XLA Bridge based on the input pipeline, which can be either TPU
52 // bridge pipeline or non TPU bridge pipeline.
RunTFXLABridge(ModuleOp module,bool enable_logging,llvm::function_ref<void (OpPassManager & pm)> pipeline_builder)53 tensorflow::Status RunTFXLABridge(
54     ModuleOp module, bool enable_logging,
55     llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
56   PassManager bridge(module.getContext());
57   ::tensorflow::applyTensorflowAndCLOptions(bridge);
58 
59   // Populate a passmanager with the list of passes that implement the bridge.
60   pipeline_builder(bridge);
61 
62   // Add set of passes to lower back to graph (from tf_executor).
63   TF::AddGraphExportLoweringPasses(bridge);
64 
65   mlir::StatusScopedDiagnosticHandler diag_handler(
66       module.getContext(), /*propagate=*/false,
67       /*filter_stack=*/!VLOG_IS_ON(1));
68 
69   if (enable_logging || VLOG_IS_ON(1)) {
70     tensorflow::DumpMlirOpToFile("tf_xla_bridge_before", module, "", &bridge);
71     if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
72   }
73   LogicalResult result = bridge.run(module);
74   (void)result;
75   if (enable_logging || VLOG_IS_ON(1))
76     tensorflow::DumpMlirOpToFile("tf_xla_bridge_after", module, "", &bridge);
77   return diag_handler.ConsumeStatus();
78 }
79 
// Populates `pm` with the TPU bridge passes shared by the V2 pipeline
// (CreateTPUBridgePipeline, on the top-level module) and the V1-compat
// pipeline (CreateTPUBridgePipelineV1, on the nested module produced by TPU V1
// island outlining).
//
// Pass order is significant. Several passes are deliberately run more than
// once (shape inference, canonicalizer, TPU cluster cleanup attributes, and
// the host-launch <-> outside-compiled conversions) because earlier passes
// change the IR those passes operate on.
void CreateTPUBridgePipelineImpl(OpPassManager &pm) {
  // The following ops must be preserved regardless of reachability. Ideally,
  // all graphs should have control dependencies to enforce this but this is
  // currently not the case (see b/177478741).
  const llvm::SmallVector<std::string, 4> ops_to_preserve = {
      "tf.TPUReplicateMetadata", "tf.TPUCompilationResult",
      "tf.TPUReplicatedOutput"};
  pm.addNestedPass<func::FuncOp>(
      tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
  // It is assumed at this stage there are no V1 control flow ops, as graph
  // functionalization is run before import. Ops can be lifted out of
  // tf_executor dialect islands/graphs.
  pm.addNestedPass<func::FuncOp>(
      CreateExecutorDialectToFunctionalConversionPass());
  // Guarantee all functions have one use, which enables more exact shape
  // inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run shape inference so that tf_executor/tf_device ops created later are
  // likely to inherit more concrete types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<func::FuncOp>(
      CreateTPUReorderReplicateAndPartitionedInputsPass());
  pm.addNestedPass<func::FuncOp>(TF::CreateDecomposeReduceDatasetPass());
  pm.addPass(CreateTPUClusterFormationPass());
  // Run TPU cluster cleanup attributes so ops with no outside compiled
  // attribute have no host device attribute.
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  pm.addNestedPass<func::FuncOp>(TFDevice::CreateDeviceAttributeToLaunchPass());
  // Running canonicalizer before decomposing resource ops in cluster helps the
  // latter pass to converge faster as it does not have to spend time folding
  // away dead ops.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
  // because DecomposeResourceOpsPass uses pattern rewriter which hoists
  // changed constants out of tf_device.Launch.
  // NOTE(review): no pass named TFExecutorConstantSinking is added in this
  // function; the constant-sinking pass added later is
  // TFDevice::CreateClusterConstantSinkingPass. Confirm this is the pass meant.
  pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
  // Encode this in its own scope so that func_pm is not mistakenly used
  // later on. Both passes below are added to the same nested func pass
  // manager.
  {
    OpPassManager &func_pm = pm.nest<func::FuncOp>();
    func_pm.addPass(CreateTPUHostComputationExpansionPass());
    func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
  }
  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());

  // TODO(b/173622615): Once OutsideCompilation is represented by launch op and
  // the remaining passes including Inliner support it, remove this
  // LaunchToDeviceAttributePass. This LaunchToDeviceAttribute pass needs to
  // come before TPUClusterCleanupAttributes pass or else the device attribute
  // will be removed from launch causing an error.
  pm.addNestedPass<func::FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());

  // TODO(b/173622615): This can be removed once more passes support outside
  // compilation represented by op and conversion back to attribute is removed.
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<func::FuncOp>(
      TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
  // Run another shape inference pass because resource decomposition might have
  // created new partial types. Also, dropping the `shape_invariant` attribute
  // from While/WhileRegion ops within clusters (previous pass) can lead to
  // more precise shapes.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  pm.addPass(TFDevice::CreateResourceOpLiftingPass());
  // Re-run the canonicalizer pass as some cleanup during resource op lifting
  // pass opens up some opportunities for canonicalization of cluster ops.
  // Specifically, we want to eliminate pass through results from the cluster
  // op.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());
  pm.addNestedPass<func::FuncOp>(createCSEPass());
  // Control-flow merging is opt-in via the MLIR common flags.
  if (tensorflow::GetMlirCommonFlags()
          ->tf_mlir_enable_merge_control_flow_pass) {
    pm.addPass(TFDevice::CreateMergeControlFlowPass());
  }

  // Outside compilation: mark ops, then extract head/tail and remaining
  // outside-compiled ops from the clusters.
  pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
  pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
  pm.addPass(CreateTPUExtractOutsideCompilationPass());

  pm.addNestedPass<func::FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
  pm.addPass(TF::CreateResourceDeviceInferencePass());
  pm.addPass(TFDevice::CreateClusterOutliningPass());
  pm.addPass(CreateTPUResourceReadForWritePass());
  pm.addPass(TFDevice::CreateMarkInputOutputAliasesPass());
  pm.addPass(CreateTPUShardingIdentificationPass());
  pm.addNestedPass<func::FuncOp>(
      CreateTPUResourceReadsWritesPartitioningPass());
  pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
  pm.addPass(CreateTPURewritePass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<func::FuncOp>(
      TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addPass(CreateTPUMergeVariablesWithExecutePass());
  pm.addNestedPass<func::FuncOp>(
      TF::CreateHoistReplicateInvariantResourceWritesPass());
  pm.addNestedPass<func::FuncOp>(CreateTPUColocateCompositeResourceOps());
  pm.addPass(CreateTPUVariableRuntimeReformattingPass());
  // Convert the region-based control flow introduced earlier back to the
  // functional form.
  pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
193 }  // namespace
194 
CreateTPUBridgePipeline(OpPassManager & pm)195 void CreateTPUBridgePipeline(OpPassManager &pm) {
196   pm.addNestedPass<func::FuncOp>(
197       CreateCanonicalizeCompileAndReplicateAttributesPass());
198   CreateTPUBridgePipelineImpl(pm);
199 }
200 
CreateTPUBridgePipelineV1(OpPassManager & pm)201 void CreateTPUBridgePipelineV1(OpPassManager &pm) {
202   // Convert to unified compilation and replication attributes.
203   pm.addNestedPass<func::FuncOp>(
204       CreateCanonicalizeCompileAndReplicateAttributesPass());
205   // Guarantee all functions have one use, which enables more exact shape
206   // inference.
207   pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
208   pm.addPass(TF::CreateTFShapeInferencePass());
209   // For V1 compatibility, we process a module where the graph does not have
210   // feeds and fetched. We extract first the TPU computation in a submodule,
211   // where it'll be in a function with args and returned values, much more like
212   // a TF v2 module. We can then run the usual pipeline on this nested module.
213   // Afterward we inline back in the parent module and delete the nested one.
214   pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
215   pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
216   OpPassManager &nested_module = pm.nest<ModuleOp>();
217   CreateTPUBridgePipelineImpl(nested_module);
218   pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
219   // There are cases where we don't consume all compilation and replication
220   // attributes like we do for the V2 pipeline, so we need to convert them from
221   // unified to legacy attributes before they get exposed to outside of the
222   // bridge.
223   pm.addNestedPass<func::FuncOp>(
224       CreateConvertToLegacyCompileAndReplicateAttributesPass());
225 }
226 
TPUBridge(ModuleOp module,bool enable_logging,bool fallback_enabled)227 tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging,
228                              bool fallback_enabled) {
229   Status status =
230       RunTFXLABridge(module, enable_logging, CreateTPUBridgePipeline);
231   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
232       "tpu", "v2", fallback_enabled,
233       status == ::tensorflow::OkStatus() ? "success" : "failure");
234   OkOrSetErrorCounterPayload(
235       tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1,
236       status);
237   return status;
238 }
TPUBridgeV1Compat(ModuleOp module,bool enable_logging,bool fallback_enabled)239 tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging,
240                                      bool fallback_enabled) {
241   Status status =
242       RunTFXLABridge(module, enable_logging, CreateTPUBridgePipelineV1);
243   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
244       "tpu", "v1", fallback_enabled,
245       status == ::tensorflow::OkStatus() ? "success" : "failure");
246   return status;
247 }
248 
249 }  // namespace TFTPU
250 
251 namespace TF {
252 
AddGraphExportLoweringPasses(OpPassManager & pm)253 void AddGraphExportLoweringPasses(OpPassManager &pm) {
254   auto add_pass = [&](std::unique_ptr<Pass> pass) {
255     pm.addNestedPass<func::FuncOp>(std::move(pass));
256     pm.addPass(CreateBreakUpIslandsPass());
257   };
258 
259   add_pass(CreateFunctionalToExecutorDialectConversionPass());
260   add_pass(TFDevice::CreateReplicateToIslandPass());
261   add_pass(TFDevice::CreateReplicaIDToDeviceOrdinalPass());
262   add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
263   add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
264   pm.addNestedPass<func::FuncOp>(TFTPU::CreateTPUDevicePropagationPass());
265   pm.addPass(createSymbolDCEPass());
266   if (tensorflow::GetMlirCommonFlags()
267           ->tf_mlir_enable_convert_control_to_data_outputs_pass) {
268     pm.addPass(tf_executor::CreateTFExecutorConvertControlToDataOutputsPass());
269   }
270   pm.addPass(CreateVerifySuitableForExportPass());
271 }
272 
RunBridgeWithStandardPipeline(ModuleOp module,bool enable_logging,bool enable_inliner)273 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
274                                                  bool enable_logging,
275                                                  bool enable_inliner) {
276   PassManager bridge(module.getContext());
277 
278   StandardPipelineOptions pipeline_options;
279   pipeline_options.enable_inliner.setValue(enable_inliner);
280   CreateTFStandardPipeline(bridge, pipeline_options);
281 
282   mlir::StatusScopedDiagnosticHandler diag_handler(
283       module.getContext(), /*propagate=*/false,
284       /*filter_stack=*/!VLOG_IS_ON(1));
285 
286   if (enable_logging || VLOG_IS_ON(1)) {
287     tensorflow::DumpMlirOpToFile("standard_pipeline_before", module, "",
288                                  &bridge);
289     if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
290   }
291   LogicalResult result = bridge.run(module);
292   (void)result;
293   if (enable_logging || VLOG_IS_ON(1))
294     tensorflow::DumpMlirOpToFile("standard_pipeline_after", module, "",
295                                  &bridge);
296   return diag_handler.ConsumeStatus();
297 }
298 
299 namespace {
CreateTFXLABridgePipeline(OpPassManager & pm)300 void CreateTFXLABridgePipeline(OpPassManager &pm) {
301   // The following ops must be preserved regardless of reachability. Ideally,
302   // all graphs should have control dependencies to enforce this.
303   VLOG(2) << "Create TF XLA Bridge pipeline";
304   const llvm::SmallVector<std::string, 4> ops_to_preserve = {};
305   pm.addNestedPass<func::FuncOp>(
306       tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
307   // It is assumed at this stage there are no V1 control flow ops as Graph
308   // functionalization is ran before import. Ops can be lifted out of
309   // tf_executor dialect islands/graphs.
310   pm.addNestedPass<func::FuncOp>(
311       CreateExecutorDialectToFunctionalConversionPass());
312   // Guarantee all functions have one use, which enables more exact shape
313   // inference.
314   pm.addPass(TF::CreateTFShapeInferencePass());
315   // Encapsulate PartitionedCall ops within a cluster so that the composite
316   // resource ops can be decomposed.
317   pm.addPass(TFDevice::CreateXlaClusterFormationPass());
318   // Running canonicalizer before decomposing resource ops in cluster helps the
319   // latter pass to converge faster as it does not have to spend time folding
320   // away dead ops.
321   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
322   // Decompose resource ops.
323   pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
324   // Run another shape inference pass because resource decomposition might have
325   // created new partial types. Also, after dropping `shape_invariant` attribute
326   // from While/WhileRegion ops within cluster would lead to more precise
327   // shapes.
328   pm.addPass(TF::CreateTFShapeInferencePass());
329   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
330   pm.addPass(TFDevice::CreateResourceOpLiftingPass());
331   // Inline the StatefulPartitionedCallOp op based in the parent region.
332   pm.addPass(TFDevice::CreateXlaInlineDeviceOpsPass());
333   // Re-run the canonicalizer pass as some cleanup during resource op lifting
334   // pass opens up some opportunities for canonicalization of cluster ops.
335   // Specifically, we want to eliminate pass through results from the cluster
336   // op.
337   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
338 
339   pm.addNestedPass<func::FuncOp>(createCSEPass());
340   pm.addPass(createSymbolDCEPass());
341   pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
342 }
343 
344 }  // namespace
345 
RunTFXLABridge(ModuleOp module,bool enable_logging)346 tensorflow::Status RunTFXLABridge(ModuleOp module, bool enable_logging) {
347   Status status = mlir::TFTPU::RunTFXLABridge(module, enable_logging,
348                                               CreateTFXLABridgePipeline);
349   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
350       /*device type*/ "cpu/gpu", /*bridge version*/ "tfxla",
351       /*fallback_enabled*/ false,
352       /*result*/ status == ::tensorflow::OkStatus() ? "success" : "failure");
353   return status;
354 }
355 
356 }  // namespace TF
357 }  // namespace mlir
358