1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
18
19 #include <memory>
20
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/MLIRContext.h" // from @llvm-project
23 #include "mlir/IR/PatternMatch.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
26
27 namespace mlir {
28
// Creates a pass that breaks up an island with multiple ops into multiple
// islands, each with a single op.
std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass();

// Creates a pass that converts MLIR functions consisting of MLIR ops into the
// tf_executor dialect, placing the ops in a single island.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateFunctionalToExecutorDialectConversionPass();

// Creates a pass that lifts the inner ops of tf_executor.island ops in a
// tf_executor.graph into the same block as the tf_executor.graph.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateExecutorDialectToFunctionalConversionPass();
42
namespace TF {
// Creates a pass that drops the `shape_invariant` attribute from
// While/WhileRegion ops.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateDropWhileShapeInvariantPass();

// Creates a pass that drops the `shape_invariant` attribute from
// While/WhileRegion ops within device clusters.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateDropWhileShapeInvariantInDeviceClusterPass();

// Creates a pass that moves writes to replicate-invariant resource variables
// outside of tf_device.replicate ops.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateHoistReplicateInvariantResourceWritesPass();

// Transforms functional control flow operations in the TensorFlow dialect to
// MLIR Control Flow Graph (CFG) form.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFFunctionalControlFlowToCFG();

// Transforms functional control flow operations in the TensorFlow dialect to
// their region-based counterparts.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFFunctionalControlFlowToRegions();

// Transforms region-based control flow operations in the TensorFlow dialect to
// their functional counterparts.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional();

// Materializes the MlirPassthroughOp by replacing it with the MLIR module
// attached as an attribute.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateMaterializePassthroughOpPass();

// Performs shape inference on the TensorFlow dialect using the global registry.
std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass();

// Performs tf.data optimizations.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTFDataOptimizationPass();

// Creates the transpose-moving and layout-assignment function passes.
// NOTE(review): these look like the building blocks of the layout optimization
// pipeline declared below (layout assignment + transpose cancellation) —
// confirm against its implementation.
std::unique_ptr<OperationPass<func::FuncOp>> CreateMoveTransposesPass();
std::unique_ptr<OperationPass<func::FuncOp>> CreateLayoutAssignmentPass();

// Guarantees that every FuncOp has a single use.
std::unique_ptr<OperationPass<ModuleOp>> CreateGuaranteeAllFuncsOneUsePass();

// Optional pass that unrolls BatchMatMul so that only MatMul is used.
// NOTE: the doubled "PassPass" suffix is part of the public name; keep it as is
// for compatibility with existing callers.
std::unique_ptr<OperationPass<func::FuncOp>> CreateUnrollBatchMatMulPassPass();

// Optional pass that maps TF BatchMatMul to TF Einsum.
std::unique_ptr<OperationPass<func::FuncOp>> CreateBatchMatMulToEinsumPass();

// Pass that transforms Einsum into other TF ops for the supported variants.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTransformEinsumPass();

// Optimizes a TensorFlow graph.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTFOptimizePass();
// Registers the TF optimize pass pipeline. Called from
// `mlir::registerTensorFlowPasses()` at the end of this header.
void RegisterTFOptimizePassPipeline();

// Creates a pass to rewrite RecvTPUEmbeddingActivationsOp and
// SendTPUEmbeddingGradients ops to their internal variants.
std::unique_ptr<OperationPass<func::FuncOp>> CreateRewriteTPUEmbeddingOpsPass();

// Performs specific fusions for GPU targets.
std::unique_ptr<OperationPass<func::FuncOp>> CreateGpuOpFusionPass();

// Creates a pass that decomposes ReduceDataset ops that are to be compiled into
// a while loop that iterates over the dataset and calls the reduction function.
std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeReduceDatasetPass();

// Creates a pass that converts ops that copy tensors between devices, e.g.
// tf.Identity.
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
CreateTensorDeviceCopyConversionPass();

// Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they
// have built-in broadcasting support.
std::unique_ptr<OperationPass<func::FuncOp>> CreateBroadcastFoldPass();

// Populates `patterns` with rewrite patterns that convert TensorFlow control
// flow to SCF. NOTE(review): presumably the same patterns the pass below
// applies — confirm.
void populateTfControlFlowToScfPatterns(MLIRContext* context,
                                        RewritePatternSet* patterns);
// Creates a pass to convert TensorFlow control flow to SCF.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTfControlFlowToScfPass();

// Options for the layout optimization pipeline built by
// CreateLayoutOptimizationPipeline (declared below).
struct LayoutOptimizationPipelineOptions
    : public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
  // "force-data-format": data format to force on all layout-sensitive ops.
  // NOTE(review): the accepted values are not visible in this header.
  Option<std::string> force_data_format{
      *this, "force-data-format",
      llvm::cl::desc("Force data format for all layout sensitive ops")};
  // "skip-fold-transpose-in-ops": skip folding transpose operands into ops
  // that can support different layouts.
  Option<bool> skip_fold_transpose_in_ops{
      *this, "skip-fold-transpose-in-ops",
      llvm::cl::desc("Skip folding transpose operands in Ops which can support "
                     "different layouts.")};
};

// Layout optimization assigns the optimal data layout to layout-sensitive
// operations and cancels all redundant transposes.
void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options);

// Options for the TF standard pipeline built by CreateTFStandardPipeline
// (declared below). Both options are initialized to false.
struct StandardPipelineOptions
    : public PassPipelineOptions<StandardPipelineOptions> {
  // "enable-inliner": enables the inliner.
  Option<bool> enable_inliner{*this, "enable-inliner",
                              llvm::cl::desc("Enable inliner."),
                              llvm::cl::init(false)};
  // "form-clusters": enables the cluster formation pass.
  Option<bool> form_clusters{*this, "form-clusters",
                             llvm::cl::desc("Enable Cluster Formation pass."),
                             llvm::cl::init(false)};
};

// Populates the pass manager with the passes involved in transforming or
// optimizing an MLIR graph without any target specialization.
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
void CreateTFStandardPipeline(OpPassManager& pm,
                              const StandardPipelineOptions& options);

// Propagates device attributes of resources from callers to callees.
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();

// Creates a pass that promotes resource reads/writes in `functions` to inputs
// and outputs of `functions`, assuming that resource operations have already
// been decomposed and function calls have already been inlined. If `functions`
// is empty, the pass is applied to the main function by default. The pass also
// annotates the input arguments for resources with the indices of their
// aliasing output arguments.
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass(
    llvm::ArrayRef<std::string> functions = {});

// Creates a pass that promotes tf.VarHandleOp to resource arguments for all
// functions.
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();

// Creates a pass that converts read-only reference variables to the
// corresponding resource variables.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();

// Creates a simple device assignment pass on the TF dialect for the CoreRT use
// case. `default_device` defaults to "cpu".
std::unique_ptr<OperationPass<func::FuncOp>> CreateSimpleTFDeviceAssignmentPass(
    llvm::StringRef default_device = "cpu");

// Creates a pass that performs device assignment for TF dialect ops that do not
// have a device assignment, using the device attribute of the function.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFDeviceAssignmentByFuncAttrPass();

// Performs resource lifting on the function body to hoist resource variable
// accesses outside all control flow statements.
LogicalResult ResourceLiftingForFunctionalControlFlow(func::FuncOp function);

// Converts stack ops into operations on local variables, which can later be
// removed by resource lifting. Requires known maximum sizes of stacks and
// known element shapes of push ops.
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass();

// Creates a pass to strip the "tf._noinline" attribute from the functions in
// the module.
std::unique_ptr<OperationPass<ModuleOp>> CreateStripNoinlineAttributePass();

// Converts tensor list operations into operations on buffers and sizes. Needs
// static shapes and a known maximum element count.
std::unique_ptr<OperationPass<ModuleOp>> CreateTensorListOpsDecompositionPass();

// Converts tensor array ops into operations on local variables, which can later
// be removed by resource lifting. Requires known sizes and known element shapes
// (either defined in TensorArrayV3 or implied by the first write).
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorArrayOpsDecompositionPass();

// Creates a pass that legalizes HLO to the TF dialect.
std::unique_ptr<OperationPass<func::FuncOp>> CreateLegalizeHloToTfPass();

// Creates a pass that legalizes TFG to the TF dialect.
std::unique_ptr<Pass> CreateLegalizeTFGToTFEPass();

// Adds the HLO-to-TF rewrite patterns to the specified pattern list.
void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns,
                                     MLIRContext* context);

// Matches sequences of ops to TensorFlow fused kernels. This pass should not
// generally be used beyond exporting to runtimes that support these ops. In the
// future these fusions may be codegen'd automatically.
std::unique_ptr<OperationPass<func::FuncOp>> CreateFusedKernelMatcherPass();

// Creates a function pass to select the device index / fold tf.DeviceIndex.
std::unique_ptr<OperationPass<func::FuncOp>> CreateDeviceIndexSelectorPass();

// Creates a function pass to replace InitializeTableFromTextFileV2Ops with
// LookupTableImportV2Op ops.
std::unique_ptr<OperationPass<func::FuncOp>> CreateInitTextFileToImportPass(
    std::string saved_model_dir = "");

// Creates a module pass to cluster TensorFlow ops by host. The program
// generated by this pass will have one function per host where all operations
// in the same function are placed on the same host. Each result of the per-host
// function will have a "tf.device" attribute which specifies the device
// assignment of the result.
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateClusterTFOpsByHostPass();

// Creates a pass to insert tf_device.send and tf_device.receive ops to make
// sure any argument of any op is on the same host as the op itself.
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();

// Creates a pass that adds the device attribute to every tf.Const op based on
// the device attribute of the operations that read its result. If the result of
// a tf.Const op is read by operations placed on multiple devices, then the pass
// will replicate the tf.Const op once for each device.
std::unique_ptr<OperationPass<ModuleOp>> CreateConstantOpDeviceAssignmentPass();

// Populates the supplied pass manager with the passes required to export
// to a TensorFlow Graph.
void AddGraphExportLoweringPasses(OpPassManager& pm);

// Returns a pass that verifies that every function in the module consists of a
// single tf_executor.graph and that each tf_executor.island in that graph
// contains only a single op.
std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass();

// Returns a pass that prepares TPU computations to be legal for export to
// TensorFlow.
std::unique_ptr<OperationPass<ModuleOp>>
CreatePrepareTpuComputationForTfExportPass();

// Rewrites ops that require quantized inputs or outputs to ops that allow
// non-quantized inputs and outputs.
std::unique_ptr<OperationPass<func::FuncOp>> CreateLowerQuantizedPass();

// Reorders ops so that ops of the same dialect are next to each other.
std::unique_ptr<Pass> CreateOrderByDialectPass();

// Groups ops into functions that each contain only one dialect.
std::unique_ptr<Pass> CreateGroupByDialectPass();
}  // namespace TF
279
namespace tf_executor {

// Creates a pass to chain the control outputs of while loop bodies.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorConvertControlToDataOutputsPass();

// Creates a pass to merge IslandOps from the TFExecutor dialect.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFExecutorIslandCoarseningPass();

// Creates a pass to merge IslandOps for operations marked for execution on TPU.
// This exists for V1 backward compatibility.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass();

// Creates a pass to outline TPU clusters from a single IslandOp into a nested
// module suitable for being processed as if it were a V2 module.
// This exists for V1 backward compatibility.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandOutliningPass();

// Creates a pass to inline calls to the nested TPU module; this reverses the
// effect of the `TFExecutorTPUV1IslandOutlining` pass above.
// This exists for V1 backward compatibility.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandInliningPass();

// Creates a pass to prune dead nodes from tf_executor.graph.
// NOTE(review): `ops_to_preserve` presumably lists ops that must survive
// pruning — confirm in the implementation.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTFExecutorGraphPruningPass(
    llvm::ArrayRef<std::string> ops_to_preserve = {});
}  // namespace tf_executor
311
namespace TFDevice {
// Creates a pass that forms clusters from instructions that are assigned to
// the same device.
std::unique_ptr<OperationPass<func::FuncOp>> CreateClusterFormationPass();

// Sinks `tf.Const` operations into the ClusterOp regions that use them. This is
// performed in order to limit the number of values implicitly captured in the
// region before outlining.
// NOTE(review): `filter` is an llvm::function_ref, i.e. a non-owning reference
// to a callable. The pass is run after this factory returns, so the callable
// must outlive the returned pass — do not pass a temporary lambda. Presumably
// it decides, per cluster and constant value, whether to sink — confirm.
std::unique_ptr<OperationPass<func::FuncOp>> CreateClusterConstantSinkingPass(
    llvm::function_ref<bool(tf_device::ClusterOp, ElementsAttr)> filter = {});

// Creates a pass that outlines regions of tf_device.cluster operations.
std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass();

// Creates a pass that outlines regions of tf_device.launch operations.
std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchOutliningPass();

// Creates a pass that converts tf_device::LaunchFuncOp into
// TF::PartitionedCallOp.
std::unique_ptr<OperationPass<ModuleOp>> CreateConvertLaunchFuncToTFCallPass();

// A pass that decomposes composite resource operations into primitive ones like
// ReadVariableOp, AssignVariableOp and other computations to facilitate
// transformations like resource op lifting.
std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeResourceOpsPass();

// A pass that decomposes composite resource operations in device clusters
// (tf_device.cluster ops) into primitive ones like ReadVariableOp,
// AssignVariableOp and other computations to facilitate transformations like
// resource op lifting.
std::unique_ptr<OperationPass<ModuleOp>>
CreateDecomposeResourceOpsInClusterPass();

// Creates a pass that marks TPU cluster input-output pairs reading and writing
// to the same resource variable as aliases.
std::unique_ptr<OperationPass<ModuleOp>> CreateMarkInputOutputAliasesPass();

// Creates a pass that lifts operations on external resource variables out of
// device computation nested in `tf_device::LaunchOp`, so that resource
// variable load operations all happen before the device computation and
// resource variable store operations all happen after it. After this pass,
// device computation no longer interacts with external resource variables.
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass();

// Creates a pass that lifts resource operations from the main function.
std::unique_ptr<OperationPass<ModuleOp>>
CreateResourceOpLiftingForMainFunctionPass();

// Lifts resource operations out of tf_device.launch_func ops nested in `op`.
// Returns a failure if there are remaining resource-type values that cannot
// be lifted.
LogicalResult LiftResourceOps(Operation* op);

// Creates a pass that hoists invariant operations in a `tf_device.replicate`.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateReplicateInvariantOpHoistingPass();

// Creates a pass that forms replica `tf_executor.island`s from a single
// `tf_device.replicate` island.
std::unique_ptr<OperationPass<func::FuncOp>> CreateReplicateToIslandPass();

// Creates a pass that sets the device ordinal attribute of the required op
// using the replica id attribute.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateReplicaIDToDeviceOrdinalPass();

// Creates a pass that creates `tf_executor.island`s from a single
// `tf_device.parallel_execute` island.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateParallelExecuteToIslandsPass();

// Creates a pass that annotates whether a LaunchFuncOp's parameters have the
// same data across replicas.
std::unique_ptr<OperationPass<ModuleOp>>
CreateAnnotateParameterReplicationPass();

// Creates a pass that marks unsupported ops in device clusters for outside
// compilation.
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass();

// Creates a pass that merges control flow with similar predicates.
std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass();

// Creates a pass that wraps each TensorFlow dialect op that has a `device`
// attribute in a `tf_device.launch` op with the same `device` attribute.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateDeviceAttributeToLaunchPass();

// Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
// attribute to each TensorFlow dialect op in the body based on the `device`
// attribute on the `tf_device.launch`.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateLaunchToDeviceAttributePass();

// Creates a pass that extracts ops in tf_device.launch ops with a host device
// assignment and adds an `_xla_outside_compilation` attribute value.
std::unique_ptr<OperationPass<ModuleOp>>
CreateHostLaunchToOutsideCompiledPass();

// Creates a pass that encapsulates StatefulPartitionedCallOp within a cluster.
std::unique_ptr<OperationPass<ModuleOp>> CreateXlaClusterFormationPass();

// Creates a pass that inlines the StatefulPartitionedCallOp op into its parent
// region.
std::unique_ptr<OperationPass<ModuleOp>> CreateXlaInlineDeviceOpsPass();

}  // namespace TFDevice
420
namespace TFTPU {
// Creates a pass that canonicalizes legacy compilation and replication
// attributes.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateCanonicalizeCompileAndReplicateAttributesPass();

// Creates a pass that converts unified compilation and replication
// attributes back to legacy attributes.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateConvertToLegacyCompileAndReplicateAttributesPass();

// Creates a pass that forms clusters from operations with the same
// `_replication_info` attribute.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();

// Creates a pass that cleans up the `_replication_info` attribute on operations
// that are inside a cluster.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUClusterCleanupAttributesPass();

// Creates a pass that removes Identity/IdentityN ops from a cluster.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();

// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();

// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources
// the cluster only writes to.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass();

// Creates a pass that reorders partitioned resource reads and replicated
// inputs.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUReorderReplicateAndPartitionedInputsPass();

// Creates a pass that partitions unpartitioned resource reads/writes to
// partitioned resource variables.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUResourceReadsWritesPartitioningPass();

// Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime
// ops.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass();

// Creates a pass that identifies XLASharding ops in launch ops for TPU
// computation.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass();

// Creates a pass that moves `tf.AssignVariableOp` into a
// `tf_device.parallel_execute` region if the `tf.AssignVariableOp` is the
// only consumer of a `tf_device.parallel_execute` result.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUParallelExecuteSinkResourceWritePass();

// Creates a pass that merges device variable reads/updates into the surrounding
// TPUExecute node. This allows the execute node to perform in-place variable
// updates.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUMergeVariablesWithExecutePass();

// Creates a pass that wraps ReadVariableOp/AssignVariableOp ops that consume a
// packed tensor so that they have the same device placement as the underlying
// TPU device.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUColocateCompositeResourceOps();

// Creates a pass that adds ops which perform formatting on variables at
// run time according to the compilation result.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUVariableRuntimeReformattingPass();

// Creates a pass that wraps ops with the same `_xla_outside_compilation`
// attribute value in a tf_device.launch op with a host device assignment.
std::unique_ptr<OperationPass<ModuleOp>>
CreateOutsideCompiledToHostLaunchPass();

// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
// at the head/tail of the TPU cluster to run before/after TPU computation.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass();

// Creates a pass that expands the outside compilation cluster at the head/tail
// of TPU computation by adding the outside compilation attribute to
// identity/cast ops that are only used for host computation.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUHostComputationExpansionPass();

// Creates a pass that updates inputs to TPU embedding layer enqueue ops so that
// the correct ops are invoked during training and evaluation.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass();

// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
// ops to a separate parallel_execute region to run on CPU.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass();

// Creates a pass that propagates TPU devices to users.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTPUDevicePropagationPass();

// Populates the supplied pass manager with the passes required to run the
// bridge.
void CreateTPUBridgePipeline(OpPassManager& pm);

// Populates the supplied pass manager with the passes required to run the
// bridge in V1 mode.
void CreateTPUBridgePipelineV1(OpPassManager& pm);

// Creates a pass that replicates the tf._TPUCompileMlir op on each host that
// needs the compiled program. It helps avoid transferring the compiled binary
// between hosts.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
CreateTPUCompileOpReplicationPass();

// Creates a pass that applies the space-to-depth transform to the first or
// frontier convolutions that consume host inputs on TPU.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass();

}  // namespace TFTPU
540
// Generate the pass registrations from tf_passes.h.inc into a `detail`
// namespace so that the main entry point `mlir::registerTensorFlowPasses()`
// defined below can wrap the generated `detail::registerTensorFlowPasses()`
// and additionally call TF::RegisterTFOptimizePassPipeline().
//
// Both functions have the signature `void()`, so this is shadowing rather than
// overloading. A qualified call `mlir::registerTensorFlowPasses()` resolves to
// the wrapper, because qualified lookup only follows using-directives when the
// name is not declared directly in `mlir`. An unqualified call made from
// inside namespace `mlir` would be ambiguous because of the using-directive
// below, so always call it qualified.
namespace detail {
#define GEN_PASS_REGISTRATION
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
}  // namespace detail
// Re-exports the generated registration functions into `mlir` so callers can
// keep referring to them as `mlir::<name>`.
// NOTE(review): this is a namespace-scope using-directive in a header, so it
// affects every includer. `mlir::detail` may also be used by upstream MLIR
// headers — confirm this does not expose more than intended.
using namespace detail;  // NOLINT
registerTensorFlowPasses()549 inline void registerTensorFlowPasses() {
550 detail::registerTensorFlowPasses();
551 TF::RegisterTFOptimizePassPipeline();
552 }
553
// Generated registration functions for the TFDevice passes, declared in
// namespace `mlir::TFDevice`. Redefining GEN_PASS_REGISTRATION here is
// well-formed even if it is still defined, since the replacement list is
// identical (empty).
namespace TFDevice {
#define GEN_PASS_REGISTRATION
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc"
}  // namespace TFDevice
558
559 } // namespace mlir
560
561 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
562