1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
18 
19 #include <memory>
20 
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
23 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
26 
27 namespace mlir {
28 
29 // Creates a pass that breaks up an island with multiple ops into multiple
30 // islands, each with a single op.
31 std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass();
32 
33 // Creates a pass that converts mlir functions consisting of mlir ops into a
34 // tf_executor dialect as a single island.
35 std::unique_ptr<OperationPass<func::FuncOp>>
36 CreateFunctionalToExecutorDialectConversionPass();
37 
38 // Creates a pass that lifts inner ops of tf_executor.island ops in
39 // tf_executor.graph into the same block as the tf_executor.graph.
40 std::unique_ptr<OperationPass<func::FuncOp>>
41 CreateExecutorDialectToFunctionalConversionPass();
42 
43 namespace TF {
44 // Creates a pass that drops `shape_invariant` attribute from While/WhileRegion
45 // ops.
46 std::unique_ptr<OperationPass<func::FuncOp>>
47 CreateDropWhileShapeInvariantPass();
48 
49 // Creates a pass that drops `shape_invariant` attribute from While/WhileRegion
50 // ops within device cluster.
51 std::unique_ptr<OperationPass<func::FuncOp>>
52 CreateDropWhileShapeInvariantInDeviceClusterPass();
53 
54 // Creates a pass that moves writes to replicate invariant resource variables
55 // outside tf_device.replicate op.
56 std::unique_ptr<OperationPass<func::FuncOp>>
57 CreateHoistReplicateInvariantResourceWritesPass();
58 
59 // Transforms functional control flow operations in the TensorFlow dialect to
60 // MLIR Control Flow Graph (CFG) form.
61 std::unique_ptr<OperationPass<func::FuncOp>>
62 CreateTFFunctionalControlFlowToCFG();
63 
64 // Transforms functional control flow operations in the TensorFlow dialect to
65 // their region based counterparts.
66 std::unique_ptr<OperationPass<ModuleOp>>
67 CreateTFFunctionalControlFlowToRegions();
68 
69 // Transforms region based control flow operations in the TensorFlow dialect to
70 // their functional counterparts.
71 std::unique_ptr<OperationPass<ModuleOp>>
72 CreateTFRegionControlFlowToFunctional();
73 
74 // Materialize the MlirPassthroughOp by replacing it with the MLIR module
75 // attached as an attribute.
76 std::unique_ptr<OperationPass<func::FuncOp>>
77 CreateMaterializePassthroughOpPass();
78 
79 // Performs Shape Inference on the TensorFlow dialect using the global registry.
80 std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass();
81 
82 // Performs TF.data optimizations.
83 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFDataOptimizationPass();
84 
85 std::unique_ptr<OperationPass<func::FuncOp>> CreateMoveTransposesPass();
86 std::unique_ptr<OperationPass<func::FuncOp>> CreateLayoutAssignmentPass();
87 
88 // Guarantee that all FuncOp's have a single use.
89 std::unique_ptr<OperationPass<ModuleOp>> CreateGuaranteeAllFuncsOneUsePass();
90 
91 // Optional pass which will unroll BatchMatMul and use only MatMul
92 std::unique_ptr<OperationPass<func::FuncOp>> CreateUnrollBatchMatMulPassPass();
93 
94 // Optional pass which will map TF BatchMatMul to TF Einsum
95 std::unique_ptr<OperationPass<func::FuncOp>> CreateBatchMatMulToEinsumPass();
96 
97 // Pass that transforms Einsum to other TF Ops for the supported variants.
98 std::unique_ptr<OperationPass<func::FuncOp>> CreateTransformEinsumPass();
99 
100 // Optimizes Tensorflow graph.
101 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFOptimizePass();
102 void RegisterTFOptimizePassPipeline();
103 
104 // Creates pass to rewrite RecvTPUEmbeddingActivationsOp and
105 // SendTPUEmbeddingGradients ops to internal variants.
106 std::unique_ptr<OperationPass<func::FuncOp>> CreateRewriteTPUEmbeddingOpsPass();
107 
108 // Performs specific fusion for GPU targets.
109 std::unique_ptr<OperationPass<func::FuncOp>> CreateGpuOpFusionPass();
110 
111 // Creates a pass that decomposes to be compiled ReduceDataset ops into a while
112 // loop that iterates the dataset and calls the reduction function.
113 std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeReduceDatasetPass();
114 
115 // Creates a pass that converts ops that copy tensors between devices, e.g.
116 // tf.Identity.
117 std::unique_ptr<OperationPass<mlir::func::FuncOp>>
118 CreateTensorDeviceCopyConversionPass();
119 
120 // Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they
121 // have built in broadcasting support.
122 std::unique_ptr<OperationPass<func::FuncOp>> CreateBroadcastFoldPass();
123 
124 void populateTfControlFlowToScfPatterns(MLIRContext* context,
125                                         RewritePatternSet* patterns);
126 // Create a pass to convert TensorFlow control flow to SCF.
127 std::unique_ptr<OperationPass<ModuleOp>> createConvertTfControlFlowToScfPass();
128 
// Options for the layout optimization pipeline; taken by const reference by
// CreateLayoutOptimizationPipeline() declared below.
129 struct LayoutOptimizationPipelineOptions
130     : public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
  // `force-data-format`: string option. No llvm::cl::init is supplied here, so
  // the initial value is whatever Option<std::string> default-constructs to.
131   Option<std::string> force_data_format{
132       *this, "force-data-format",
133       llvm::cl::desc("Force data format for all layout sensitive ops")};
  // `skip-fold-transpose-in-ops`: boolean option. As above, no explicit
  // llvm::cl::init is given in this declaration.
134   Option<bool> skip_fold_transpose_in_ops{
135       *this, "skip-fold-transpose-in-ops",
136       llvm::cl::desc("Skip folding transpose operands in Ops which can support "
137                      "different layouts.")};
138 };
139 
140 // Layout optimization assigns optimal data layout for layout sensitive
141 // operations, and cancels all redundant transposes.
142 void CreateLayoutOptimizationPipeline(
143     OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
144     const LayoutOptimizationPipelineOptions& options);
145 
// Options for the TF standard pipeline; taken by const reference by
// CreateTFStandardPipeline() declared below. Both flags are explicitly
// initialized to false.
146 struct StandardPipelineOptions
147     : public PassPipelineOptions<StandardPipelineOptions> {
  // `enable-inliner`: boolean option, initialized to false.
148   Option<bool> enable_inliner{*this, "enable-inliner",
149                               llvm::cl::desc("Enable inliner."),
150                               llvm::cl::init(false)};
  // `form-clusters`: boolean option, initialized to false.
151   Option<bool> form_clusters{*this, "form-clusters",
152                              llvm::cl::desc("Enable Cluster Formation pass."),
153                              llvm::cl::init(false)};
154 };
155 
156 // Populates the pass manager with the passes involved in transforming or
157 // optimizing an MLIR graph without any target specialization.
158 // NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
159 void CreateTFStandardPipeline(OpPassManager& pm,
160                               const StandardPipelineOptions& options);
161 
162 // Propagates device attributes of resources from callers to callees.
163 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();
164 
165 // Creates a pass that promotes resource reads/writes in `functions` to inputs
166 // and outputs of `functions`, assuming that resource operations have already
167 // been decomposed and function calls have already been inlined. If `functions`
168 // is empty, the pass is applied to the main function by default. The pass also
169 // annotates the input arguments for resources with the indices of their
170 // aliasing output arguments.
171 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass(
172     llvm::ArrayRef<std::string> functions = {});
173 
174 // Creates a pass that promotes tf.VarHandleOp to resource arguments for all
175 // functions.
176 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
177 
178 // Creates a pass that converts readonly reference variables to the
179 // corresponding resource variables.
180 std::unique_ptr<OperationPass<func::FuncOp>>
181 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();
182 
183 // Creates a simple device assignment pass on TF dialect for CoreRT use case.
184 std::unique_ptr<OperationPass<func::FuncOp>> CreateSimpleTFDeviceAssignmentPass(
185     llvm::StringRef default_device = "cpu");
186 
187 // Creates a pass to perform device assignment for TF dialect ops that do not
188 // have device assignment, by using the device attribute of the function.
189 std::unique_ptr<OperationPass<func::FuncOp>>
190 CreateTFDeviceAssignmentByFuncAttrPass();
191 
192 // Performs resource lifting on the function body to hoist resource variable
193 // accesses outside all control flow statements.
194 LogicalResult ResourceLiftingForFunctionalControlFlow(func::FuncOp function);
195 
196 // Converts stack ops into operations on local variables, which can later be
197 // removed by resource lifting. Requires known maximum sizes of stacks and
198 // known element shapes of push ops.
199 std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass();
200 
201 // Creates a pass to strip the "tf._noinline" attribute from the functions in
202 // the module.
203 std::unique_ptr<OperationPass<ModuleOp>> CreateStripNoinlineAttributePass();
204 
205 // Converts tensor list operations into operations on buffers and sizes. Needs
206 // static shapes and known max element count.
207 std::unique_ptr<OperationPass<ModuleOp>> CreateTensorListOpsDecompositionPass();
208 
209 // Converts tensor array ops into operations on local variables, which can later
210 // be removed by resource lifting. Requires known sizes and known element shapes
211 // (either defined in TensorArrayV3 or implied in the first write).
212 std::unique_ptr<OperationPass<ModuleOp>>
213 CreateTensorArrayOpsDecompositionPass();
214 
215 // Creates a pass that legalizes HLO to TF dialect.
216 std::unique_ptr<OperationPass<func::FuncOp>> CreateLegalizeHloToTfPass();
217 
218 // Creates a pass that legalizes TFG to TF dialect.
219 std::unique_ptr<Pass> CreateLegalizeTFGToTFEPass();
220 
221 // Adds the HLO to TF rewrite patterns to the specified pattern list.
222 void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns,
223                                      MLIRContext* context);
224 
225 // Matches sequence of ops to TensorFlow fused kernels. This pass should not be
226 // generally used beyond exporting to runtimes that supports these ops. In the
227 // future these fusions may be codegen'd automatically.
228 std::unique_ptr<OperationPass<func::FuncOp>> CreateFusedKernelMatcherPass();
229 
230 // Creates function pass to select device index/fold tf.DeviceIndex.
231 std::unique_ptr<OperationPass<func::FuncOp>> CreateDeviceIndexSelectorPass();
232 
233 // Creates function pass to replace InitializeTableFromTextFileV2Ops with
234 // LookupTableImportV2Op ops.
235 std::unique_ptr<OperationPass<func::FuncOp>> CreateInitTextFileToImportPass(
236     std::string saved_model_dir = "");
237 
238 // Creates function pass to cluster TensorFlow ops by host. The program
239 // generated by this pass will have one function per host where all operations
240 // in the same function are placed on the same host. Each result of the per-host
241 // function will have a "tf.device" attribute which specifies the device
242 // assignment of the result.
243 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateClusterTFOpsByHostPass();
244 
245 // Creates a pass to insert tf_device.send and tf_device.receive ops to make
246 // sure any argument of any op is on the same host of the op itself.
247 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();
248 
249 // Creates a pass that adds the device attribute to every tf.Const op based on
250 // the device attribute of the operations that read its result. If the result of
251 // a tf.Const op is read by operations placed on multiple devices, then the pass
252 // will replicate the tf.Const op once for each device.
253 std::unique_ptr<OperationPass<ModuleOp>> CreateConstantOpDeviceAssignmentPass();
254 
255 // Populates the supplied passmanager with the passes required to export
256 // to TensorFlow Graph.
257 void AddGraphExportLoweringPasses(OpPassManager& pm);
258 
259 // Returns pass that verifies whether all functions in module are of single
260 // tf_executor.graph and each tf_executor.island in tf_executor.graph only has a
261 // single op.
262 std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass();
263 
264 // Returns pass that prepares TPU computation to be legal for export to
265 // TensorFlow.
266 std::unique_ptr<OperationPass<ModuleOp>>
267 CreatePrepareTpuComputationForTfExportPass();
268 
269 // Rewrites ops that require quantized inputs or outputs to ops that allow
270 // non-quantized inputs and outputs.
271 std::unique_ptr<OperationPass<func::FuncOp>> CreateLowerQuantizedPass();
272 
273 // Reorders ops so ops of the same dialect are next to each other.
274 std::unique_ptr<Pass> CreateOrderByDialectPass();
275 
276 // Groups ops into functions that only contain one dialect.
277 std::unique_ptr<Pass> CreateGroupByDialectPass();
278 }  // namespace TF
279 
280 namespace tf_executor {
281 
282 // Creates a pass to chain control outputs of while loop body.
283 std::unique_ptr<OperationPass<ModuleOp>>
284 CreateTFExecutorConvertControlToDataOutputsPass();
285 
286 // Creates a pass to merge IslandOps from TFExecutor dialect.
287 std::unique_ptr<OperationPass<func::FuncOp>>
288 CreateTFExecutorIslandCoarseningPass();
289 
290 // Creates a pass to merge IslandOps for operation marked for execution on TPU.
291 // This is a V1 backward compatibility.
292 std::unique_ptr<OperationPass<ModuleOp>>
293 CreateTFExecutorTPUV1IslandCoarseningPass();
294 
295 // Creates a pass to outline TPU clusters from single IslandOp into a nested
296 // module suitable for being processed as-if it was a V2 module.
297 // This is a V1 backward compatibility.
298 std::unique_ptr<OperationPass<ModuleOp>>
299 CreateTFExecutorTPUV1IslandOutliningPass();
300 
301 // Creates a pass to inline calls to the nested TPU module, this reverses the
302 // effect of the `TFExecutorTPUV1IslandOutlining` pass above.
303 // This is a V1 backward compatibility.
304 std::unique_ptr<OperationPass<ModuleOp>>
305 CreateTFExecutorTPUV1IslandInliningPass();
306 
307 // Creates a pass to prune tf_executor.graph from dead nodes.
308 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFExecutorGraphPruningPass(
309     llvm::ArrayRef<std::string> ops_to_preserve = {});
310 }  // namespace tf_executor
311 
312 namespace TFDevice {
313 // Creates a pass that forms clusters from instructions that are assigned to
314 // same device.
315 std::unique_ptr<OperationPass<func::FuncOp>> CreateClusterFormationPass();
316 
317 // Sinks `tf.Const` operations in the ClusterOp region using them. This is
318 // performed in order to limit the number of values implicitly captured in this
319 // region before outlining.
320 std::unique_ptr<OperationPass<func::FuncOp>> CreateClusterConstantSinkingPass(
321     llvm::function_ref<bool(tf_device::ClusterOp, ElementsAttr)> filter = {});
322 
323 // Creates a pass that outlines regions of tf_device.cluster operations.
324 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass();
325 
326 // Creates a pass that outlines regions of tf_device.launch operations.
327 std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchOutliningPass();
328 
329 // Creates a pass that converts tf_device::LaunchFuncOp into
330 // TF::PartitionedCallOp.
331 std::unique_ptr<OperationPass<ModuleOp>> CreateConvertLaunchFuncToTFCallPass();
332 
333 // A pass that decomposes composite resource operations into primitive ones like
334 // ReadVariableOp, AssignVariableOp and other computations to facilitate
335 // transformations like resource op lifting.
336 std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeResourceOpsPass();
337 
338 // A pass that decomposes composite resource operations in device cluster
339 // (tf_device.cluster op) into primitive ones like ReadVariableOp,
340 // AssignVariableOp and other computations to facilitate transformations like
341 // resource op lifting.
342 std::unique_ptr<OperationPass<ModuleOp>>
343 CreateDecomposeResourceOpsInClusterPass();
344 
345 // Creates a pass that marks TPU cluster input-output pairs reading and writing
346 // to same resource variable as aliases.
347 std::unique_ptr<OperationPass<ModuleOp>> CreateMarkInputOutputAliasesPass();
348 
349 // Creates a pass that lifts operations on external resource variables from
350 // device computation nested in `tf_device::LaunchOp` out so that resource
351 // variable load operations are all before device computation while resource
352 // variable store operations are all after device computation. After this pass,
353 // device computation no longer interacts with external resource variables.
354 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass();
355 
356 // Creates a pass that lifts operations from the main function.
357 std::unique_ptr<OperationPass<ModuleOp>>
358 CreateResourceOpLiftingForMainFunctionPass();
359 
360 // Lifts resource operations from tf_device.launch_func ops nested in `op`
361 // outside. Returns a failure if there are remaining resource-type values that
362 // can not be lifted.
363 LogicalResult LiftResourceOps(Operation* op);
364 
365 // Creates a pass that hoists invariant operations in a `tf_device.replicate`.
366 std::unique_ptr<OperationPass<func::FuncOp>>
367 CreateReplicateInvariantOpHoistingPass();
368 
369 // Creates a pass that forms replica `tf_executor.island` from a single
370 // `tf_device.replicate` island.
371 std::unique_ptr<OperationPass<func::FuncOp>> CreateReplicateToIslandPass();
372 
373 // Creates a pass that sets the device ordinal attribute of the required op
374 // using the replica id attribute.
375 std::unique_ptr<OperationPass<func::FuncOp>>
376 CreateReplicaIDToDeviceOrdinalPass();
377 
378 // Creates a pass that creates `tf_executor.island` from a single
379 // `tf_device.parallel_execute` island.
380 std::unique_ptr<OperationPass<func::FuncOp>>
381 CreateParallelExecuteToIslandsPass();
382 
383 // Creates a pass that annotates whether a LaunchFuncOp's parameters have the
384 // same data across replicas.
385 std::unique_ptr<OperationPass<ModuleOp>>
386 CreateAnnotateParameterReplicationPass();
387 
388 // Creates a pass that marks unsupported ops in device cluster for outside
389 // compilation.
390 std::unique_ptr<OperationPass<ModuleOp>>
391 CreateMarkOpsForOutsideCompilationPass();
392 
393 // Creates a pass that merges control flow with similar predicates.
394 std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass();
395 
396 // Creates a pass that wraps each TensorFlow dialect op with `device` attribute
397 // in a `tf_device.launch` op with the same `device` attribute.
398 std::unique_ptr<OperationPass<func::FuncOp>>
399 CreateDeviceAttributeToLaunchPass();
400 
401 // Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
402 // attribute to each TensorFlow dialect op in the body based on the `device`
403 // attribute on the `tf_device.launch`.
404 std::unique_ptr<OperationPass<func::FuncOp>>
405 CreateLaunchToDeviceAttributePass();
406 
407 // Creates a pass that extracts ops in tf_device.launch op with host device
408 // assignment and adds an `_xla_outside_compilation` attribute value.
409 std::unique_ptr<OperationPass<ModuleOp>>
410 CreateHostLaunchToOutsideCompiledPass();
411 
412 // Create a pass that encapsulates StatefulPartitionedCallOp within a cluster.
413 std::unique_ptr<OperationPass<ModuleOp>> CreateXlaClusterFormationPass();
414 
415 // Create a pass that inlines the StatefulPartitionedCallOp op based in the
416 // parent region.
417 std::unique_ptr<OperationPass<ModuleOp>> CreateXlaInlineDeviceOpsPass();
418 
419 }  // namespace TFDevice
420 
421 namespace TFTPU {
422 // Creates a pass that canonicalizes legacy compilation and replication
423 // attributes.
424 std::unique_ptr<OperationPass<func::FuncOp>>
425 CreateCanonicalizeCompileAndReplicateAttributesPass();
426 
427 // Creates a pass that converts unified compilation and replication
428 // attributes back to legacy attributes.
429 std::unique_ptr<OperationPass<func::FuncOp>>
430 CreateConvertToLegacyCompileAndReplicateAttributesPass();
431 
432 // Creates a pass that forms clusters from operations of the same
433 // `_replication_info` attribute.
434 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();
435 
436 // Creates a pass that cleans up `_replication_info` attribute on operations
437 // that are inside a cluster.
438 std::unique_ptr<OperationPass<ModuleOp>>
439 CreateTPUClusterCleanupAttributesPass();
440 
441 // Creates a pass that removes Identity/IdentityN ops from a cluster.
442 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();
443 
444 // Creates a pass that allows TPU program inputs to have layouts determined at
445 // run time.
446 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();
447 
448 // Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources
449 // the cluster only writes to.
450 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass();
451 
452 // Creates a pass that reorders partitioned resource reads and replicated
453 // inputs.
454 std::unique_ptr<OperationPass<func::FuncOp>>
455 CreateTPUReorderReplicateAndPartitionedInputsPass();
456 
457 // Creates a pass that partitions unpartitioned resource read/write to
458 // partitioned resource variables.
459 std::unique_ptr<OperationPass<func::FuncOp>>
460 CreateTPUResourceReadsWritesPartitioningPass();
461 
462 // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime
463 // ops.
464 std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass();
465 
466 // Creates a pass that identifies XLASharding ops in launch op for TPU
467 // computation.
468 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass();
469 
470 // Creates a pass that moves `tf.AssignVariableOp` into a
471 // `tf_device.parallel_execute` region if the `tf.AssignVariableOp` is the
472 // only consumer of a `tf_device.parallel_execute` result.
473 std::unique_ptr<OperationPass<func::FuncOp>>
474 CreateTPUParallelExecuteSinkResourceWritePass();
475 
476 // Creates a pass that merges device variable reads/updates into the surrounded
477 // TPUExecute node. This allows the execute node to perform in-place variable
478 // updates.
479 std::unique_ptr<OperationPass<ModuleOp>>
480 CreateTPUMergeVariablesWithExecutePass();
481 
482 // Creates a pass that wraps ReadVariableOp/AssignVariable op that consumes a
483 // packed tensor to have same device placement as underlying TPU device.
484 std::unique_ptr<OperationPass<func::FuncOp>>
485 CreateTPUColocateCompositeResourceOps();
486 
487 // Creates a pass that adds ops which perform formatting on variables at
488 // run-time according to compilation result.
489 std::unique_ptr<OperationPass<ModuleOp>>
490 CreateTPUVariableRuntimeReformattingPass();
491 
492 // Creates a pass that wraps ops with the same `_xla_outside_compilation`
493 // attribute value in a tf_device.launch op with host device assignment.
494 std::unique_ptr<OperationPass<ModuleOp>>
495 CreateOutsideCompiledToHostLaunchPass();
496 
497 // Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
498 // at head/tail of TPU cluster to run before/after TPU computation.
499 std::unique_ptr<OperationPass<ModuleOp>>
500 CreateTPUExtractHeadTailOutsideCompilationPass();
501 
502 // Creates a pass that expands outside compilation cluster at the head/tail of
503 // TPU computation by adding outside compilation attribute to identity/cast ops
504 // that are only used for host computation.
505 std::unique_ptr<OperationPass<func::FuncOp>>
506 CreateTPUHostComputationExpansionPass();
507 
508 // Creates a pass that updates inputs to TPU embedding layer enqueue ops so that
509 // correct ops are invoked during training and evaluation.
510 std::unique_ptr<OperationPass<func::FuncOp>>
511 CreateTPUUpdateEmbeddingEnqueueOpInputsPass();
512 
513 // Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
514 // ops to a separate parallel_execute region to run on CPU.
515 std::unique_ptr<OperationPass<ModuleOp>>
516 CreateTPUExtractOutsideCompilationPass();
517 
518 // Creates a pass that propagates TPU devices to users.
519 std::unique_ptr<OperationPass<func::FuncOp>> CreateTPUDevicePropagationPass();
520 
521 // Populates the supplied passmanager with the passes required to run the
522 // bridge.
523 void CreateTPUBridgePipeline(OpPassManager& pm);
524 
525 // Populates the supplied passmanager with the passes required to run the
526 // bridge in V1 mode.
527 void CreateTPUBridgePipelineV1(OpPassManager& pm);
528 
529 // Creates a pass that replicates the tf._TPUCompileMlir op on each host that
530 // needs the compiled program. It helps avoid transferring the compiled binary
531 // between hosts.
532 std::unique_ptr<OperationPass<mlir::ModuleOp>>
533 CreateTPUCompileOpReplicationPass();
534 
535 // Creates a pass that applies space to depth transform
536 // for the first or frontier convolutions that consume host inputs on TPU.
537 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass();
538 
539 }  // namespace TFTPU
540 
541 // Define the registrations in a detail namespace, just so that we can overload
542 // the main entry point `registerTensorFlowPasses` to inject
543 // RegisterTFOptimizePassPipeline.
544 namespace detail {
545 #define GEN_PASS_REGISTRATION
546 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
547 }  // namespace detail
548 using namespace detail;  // NOLINT
// Registration entry point in namespace mlir. It first calls the function of
// the same name in namespace `detail` (presumably the one produced by the
// GEN_PASS_REGISTRATION include of tf_passes.h.inc above), and then calls
// TF::RegisterTFOptimizePassPipeline().
registerTensorFlowPasses()549 inline void registerTensorFlowPasses() {
550   detail::registerTensorFlowPasses();
551   TF::RegisterTFOptimizePassPipeline();
552 }
553 
554 namespace TFDevice {
555 #define GEN_PASS_REGISTRATION
556 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc"
557 }  // namespace TFDevice
558 
559 }  // namespace mlir
560 
561 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
562