/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"

#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/platform/error_payloads.h"
#include "tensorflow/core/protobuf/core_platform_payloads.pb.h"

namespace mlir {
namespace {

// Add a logger to the bridge pass manager.
// Enable per-pass timing statistics for the bridge pass manager.
void EnableDetailedLogging(PassManager *pm) {
  // Print the whole module after each pass, which requires disabling
  // multi-threading as well.
  pm->getContext()->disableMultithreading();
  pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
      /*print_module_scope=*/true));
  pm->enableTiming();
}
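// Illustrative sketch (not part of this file's API surface): within this
// translation unit, detailed logging could be attached to a fresh pass
// manager as below. The `module` handle is a hypothetical stand-in for a
// real ModuleOp.
//
//   mlir::PassManager pm(module.getContext());
//   EnableDetailedLogging(&pm);  // dump IR after each pass, single-threaded
//   (void)pm.run(module);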
}  // namespace

namespace TFTPU {

namespace {
// Run the TF XLA Bridge with the given pipeline, which can be either the TPU
// bridge pipeline or the non-TPU bridge pipeline.
tensorflow::Status RunTFXLABridge(
    ModuleOp module, bool enable_logging,
    llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
  PassManager bridge(module.getContext());
  ::tensorflow::applyTensorflowAndCLOptions(bridge);

  // Populate a pass manager with the list of passes that implement the
  // bridge.
  pipeline_builder(bridge);

  // Add the set of passes to lower back to graph (from tf_executor).
  TF::AddGraphExportLoweringPasses(bridge);

  mlir::StatusScopedDiagnosticHandler diag_handler(
      module.getContext(), /*propagate=*/false,
      /*filter_stack=*/!VLOG_IS_ON(1));

  if (enable_logging || VLOG_IS_ON(1)) {
    tensorflow::DumpMlirOpToFile("tf_xla_bridge_before", module, "", &bridge);
    if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
  }
  LogicalResult result = bridge.run(module);
  (void)result;
  if (enable_logging || VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("tf_xla_bridge_after", module, "", &bridge);
  return diag_handler.ConsumeStatus();
}
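// Illustrative sketch (not part of this file): the helper above is what the
// exported entry points below wrap. A hypothetical caller in this namespace
// could run a custom pipeline like so:
//
//   tensorflow::Status s = RunTFXLABridge(
//       module, /*enable_logging=*/false, [](mlir::OpPassManager &pm) {
//         pm.addPass(mlir::createCanonicalizerPass());
//       });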

void CreateTPUBridgePipelineImpl(OpPassManager &pm) {
  // The following ops must be preserved regardless of reachability. Ideally,
  // all graphs should have control dependencies to enforce this but this is
  // currently not the case (see b/177478741).
  const llvm::SmallVector<std::string, 4> ops_to_preserve = {
      "tf.TPUReplicateMetadata", "tf.TPUCompilationResult",
      "tf.TPUReplicatedOutput"};
  pm.addNestedPass<func::FuncOp>(
      tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
  // It is assumed that at this stage there are no V1 control flow ops, as
  // Graph functionalization is run before import. Ops can be lifted out of
  // tf_executor dialect islands/graphs.
  pm.addNestedPass<func::FuncOp>(
      CreateExecutorDialectToFunctionalConversionPass());
  // Guarantee all functions have one use, which enables more exact shape
  // inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run shape inference so that tf_executor/tf_device ops created later are
  // more likely to inherit concrete types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<func::FuncOp>(
      CreateTPUReorderReplicateAndPartitionedInputsPass());
  pm.addNestedPass<func::FuncOp>(TF::CreateDecomposeReduceDatasetPass());
  pm.addPass(CreateTPUClusterFormationPass());
  // Clean up TPU cluster attributes so that ops without an outside
  // compilation attribute carry no host device attribute.
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  pm.addNestedPass<func::FuncOp>(TFDevice::CreateDeviceAttributeToLaunchPass());
  // Running canonicalizer before decomposing resource ops in cluster helps the
  // latter pass to converge faster as it does not have to spend time folding
  // away dead ops.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
  // because DecomposeResourceOpsPass uses pattern rewriter which hoists
  // changed constants out of tf_device.Launch.
  pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
  // Encode this in its own scope so that func_pm is not mistakenly used
  // later on.
  {
    OpPassManager &func_pm = pm.nest<func::FuncOp>();
    func_pm.addPass(CreateTPUHostComputationExpansionPass());
    func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
  }
  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());

  // TODO(b/173622615): Once OutsideCompilation is represented by launch op and
  // the remaining passes including Inliner support it, remove this
  // LaunchToDeviceAttributePass. This LaunchToDeviceAttribute pass needs to
  // come before TPUClusterCleanupAttributes pass or else the device attribute
  // will be removed from launch causing an error.
  pm.addNestedPass<func::FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());

  // TODO(b/173622615): This can be removed once more passes support outside
  // compilation represented by op and conversion back to attribute is removed.
  pm.addPass(CreateOutsideCompiledToHostLaunchPass());
  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<func::FuncOp>(
      TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
  // Run another shape inference pass because resource decomposition might have
  // created new partial types. Also, dropping the `shape_invariant` attribute
  // from While/WhileRegion ops within the cluster leads to more precise
  // shapes.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
  pm.addPass(TFDevice::CreateResourceOpLiftingPass());
  // Re-run the canonicalizer pass as some cleanup during the resource op
  // lifting pass opens up opportunities for canonicalization of cluster ops.
  // Specifically, we want to eliminate pass-through results from the cluster
  // op.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

  // TODO(b/173622615): This should incrementally be moved down as
  // more passes support this representation and then can be removed once
  // all passes support it.
  pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());
  pm.addNestedPass<func::FuncOp>(createCSEPass());
  if (tensorflow::GetMlirCommonFlags()
          ->tf_mlir_enable_merge_control_flow_pass) {
    pm.addPass(TFDevice::CreateMergeControlFlowPass());
  }

  pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
  pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
  pm.addPass(CreateTPUExtractOutsideCompilationPass());

  pm.addNestedPass<func::FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
  pm.addPass(TF::CreateResourceDeviceInferencePass());
  pm.addPass(TFDevice::CreateClusterOutliningPass());
  pm.addPass(CreateTPUResourceReadForWritePass());
  pm.addPass(TFDevice::CreateMarkInputOutputAliasesPass());
  pm.addPass(CreateTPUShardingIdentificationPass());
  pm.addNestedPass<func::FuncOp>(
      CreateTPUResourceReadsWritesPartitioningPass());
  pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
  pm.addPass(CreateTPURewritePass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<func::FuncOp>(
      TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addPass(CreateTPUMergeVariablesWithExecutePass());
  pm.addNestedPass<func::FuncOp>(
      TF::CreateHoistReplicateInvariantResourceWritesPass());
  pm.addNestedPass<func::FuncOp>(CreateTPUColocateCompositeResourceOps());
  pm.addPass(CreateTPUVariableRuntimeReformattingPass());
  pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
}  // namespace

void CreateTPUBridgePipeline(OpPassManager &pm) {
  pm.addNestedPass<func::FuncOp>(
      CreateCanonicalizeCompileAndReplicateAttributesPass());
  CreateTPUBridgePipelineImpl(pm);
}
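// Illustrative sketch (an assumption, not present in this file): because the
// pipeline builder has the standard `void(OpPassManager &)` shape, it could
// also be exposed to mlir-opt style tools via MLIR's pipeline registration:
//
//   static mlir::PassPipelineRegistration<> tpu_bridge(
//       "tf-tpu-bridge", "Runs the TPU bridge pipeline",
//       mlir::TFTPU::CreateTPUBridgePipeline);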

void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  // Convert to unified compilation and replication attributes.
  pm.addNestedPass<func::FuncOp>(
      CreateCanonicalizeCompileAndReplicateAttributesPass());
  // Guarantee all functions have one use, which enables more exact shape
  // inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  pm.addPass(TF::CreateTFShapeInferencePass());
  // For V1 compatibility, we process a module where the graph does not have
  // feeds and fetches. We first extract the TPU computation into a submodule,
  // where it lives in a function with arguments and return values, much like
  // a TF v2 module. We can then run the usual pipeline on this nested module.
  // Afterward we inline it back into the parent module and delete the nested
  // one.
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
  OpPassManager &nested_module = pm.nest<ModuleOp>();
  CreateTPUBridgePipelineImpl(nested_module);
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
  // There are cases where we don't consume all compilation and replication
  // attributes like we do for the V2 pipeline, so we need to convert them from
  // unified to legacy attributes before they are exposed outside of the
  // bridge.
  pm.addNestedPass<func::FuncOp>(
      CreateConvertToLegacyCompileAndReplicateAttributesPass());
}

tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging,
                             bool fallback_enabled) {
  Status status =
      RunTFXLABridge(module, enable_logging, CreateTPUBridgePipeline);
  tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
      "tpu", "v2", fallback_enabled,
      status == ::tensorflow::OkStatus() ? "success" : "failure");
  OkOrSetErrorCounterPayload(
      tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1,
      status);
  return status;
}
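// Illustrative sketch (not part of this file): a typical phase-1 invocation
// from bridge setup code, with `module` a hypothetical ModuleOp handle:
//
//   tensorflow::Status s = mlir::TFTPU::TPUBridge(
//       module, /*enable_logging=*/VLOG_IS_ON(1),
//       /*fallback_enabled=*/false);
//   if (!s.ok()) { /* surface the error to the caller */ }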
tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging,
                                     bool fallback_enabled) {
  Status status =
      RunTFXLABridge(module, enable_logging, CreateTPUBridgePipelineV1);
  tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
      "tpu", "v1", fallback_enabled,
      status == ::tensorflow::OkStatus() ? "success" : "failure");
  return status;
}

}  // namespace TFTPU

namespace TF {

void AddGraphExportLoweringPasses(OpPassManager &pm) {
  // Each lowering pass below is followed by a BreakUpIslands run so that the
  // one-op-per-island invariant is restored before the next pass executes.
  auto add_pass = [&](std::unique_ptr<Pass> pass) {
    pm.addNestedPass<func::FuncOp>(std::move(pass));
    pm.addPass(CreateBreakUpIslandsPass());
  };

  add_pass(CreateFunctionalToExecutorDialectConversionPass());
  add_pass(TFDevice::CreateReplicateToIslandPass());
  add_pass(TFDevice::CreateReplicaIDToDeviceOrdinalPass());
  add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
  add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
  pm.addNestedPass<func::FuncOp>(TFTPU::CreateTPUDevicePropagationPass());
  pm.addPass(createSymbolDCEPass());
  if (tensorflow::GetMlirCommonFlags()
          ->tf_mlir_enable_convert_control_to_data_outputs_pass) {
    pm.addPass(tf_executor::CreateTFExecutorConvertControlToDataOutputsPass());
  }
  pm.addPass(CreateVerifySuitableForExportPass());
}
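// Illustrative sketch (not part of this file): these lowering passes are
// appended to a custom pipeline before exporting back to a Graph, e.g.:
//
//   mlir::PassManager pm(module.getContext());
//   // ... custom pipeline passes ...
//   mlir::TF::AddGraphExportLoweringPasses(pm);
//   (void)pm.run(module);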

tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
                                                 bool enable_logging,
                                                 bool enable_inliner) {
  PassManager bridge(module.getContext());

  StandardPipelineOptions pipeline_options;
  pipeline_options.enable_inliner.setValue(enable_inliner);
  CreateTFStandardPipeline(bridge, pipeline_options);

  mlir::StatusScopedDiagnosticHandler diag_handler(
      module.getContext(), /*propagate=*/false,
      /*filter_stack=*/!VLOG_IS_ON(1));

  if (enable_logging || VLOG_IS_ON(1)) {
    tensorflow::DumpMlirOpToFile("standard_pipeline_before", module, "",
                                 &bridge);
    if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
  }
  LogicalResult result = bridge.run(module);
  (void)result;
  if (enable_logging || VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("standard_pipeline_after", module, "",
                                 &bridge);
  return diag_handler.ConsumeStatus();
}
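// Illustrative sketch (not part of this file): running the standard,
// non-device-specific pipeline with inlining enabled:
//
//   tensorflow::Status s = mlir::TF::RunBridgeWithStandardPipeline(
//       module, /*enable_logging=*/false, /*enable_inliner=*/true);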

namespace {
void CreateTFXLABridgePipeline(OpPassManager &pm) {
  // The following ops must be preserved regardless of reachability. Ideally,
  // all graphs should have control dependencies to enforce this.
  VLOG(2) << "Create TF XLA Bridge pipeline";
  const llvm::SmallVector<std::string, 4> ops_to_preserve = {};
  pm.addNestedPass<func::FuncOp>(
      tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
  // It is assumed that at this stage there are no V1 control flow ops, as
  // Graph functionalization is run before import. Ops can be lifted out of
  // tf_executor dialect islands/graphs.
  pm.addNestedPass<func::FuncOp>(
      CreateExecutorDialectToFunctionalConversionPass());
  // Run shape inference so that later passes see more concrete types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  // Encapsulate PartitionedCall ops within a cluster so that the composite
  // resource ops can be decomposed.
  pm.addPass(TFDevice::CreateXlaClusterFormationPass());
  // Running canonicalizer before decomposing resource ops in cluster helps the
  // latter pass to converge faster as it does not have to spend time folding
  // away dead ops.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  // Decompose resource ops.
  pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
  // Run another shape inference pass because resource decomposition might have
  // created new partial types. Also, dropping the `shape_invariant` attribute
  // from While/WhileRegion ops within the cluster leads to more precise
  // shapes.
  pm.addPass(TF::CreateTFShapeInferencePass());
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  pm.addPass(TFDevice::CreateResourceOpLiftingPass());
  // Inline the device cluster ops (formed above from StatefulPartitionedCall
  // ops) into the parent region.
  pm.addPass(TFDevice::CreateXlaInlineDeviceOpsPass());
  // Re-run the canonicalizer pass as some cleanup during the resource op
  // lifting pass opens up opportunities for canonicalization of cluster ops.
  // Specifically, we want to eliminate pass-through results from the cluster
  // op.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

  pm.addNestedPass<func::FuncOp>(createCSEPass());
  pm.addPass(createSymbolDCEPass());
  pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}

}  // namespace

tensorflow::Status RunTFXLABridge(ModuleOp module, bool enable_logging) {
  Status status = mlir::TFTPU::RunTFXLABridge(module, enable_logging,
                                              CreateTFXLABridgePipeline);
  tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
      /*device_type=*/"cpu/gpu", /*bridge_version=*/"tfxla",
      /*fallback_enabled=*/false,
      /*result=*/status == ::tensorflow::OkStatus() ? "success" : "failure");
  return status;
}

}  // namespace TF
}  // namespace mlir