1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/Block.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/Operation.h" // from @llvm-project
22 #include "mlir/IR/Value.h" // from @llvm-project
23 #include "mlir/IR/Visitors.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29
30 namespace mlir {
31 namespace TFTPU {
32
// This pass expands outside compilation attributes to Identity/Cast ops at the
// head of the TPU computation when those ops are only used by outside compiled
// ops.
35
36 namespace {
37
// Name of the string attribute that marks an op as outside compiled. It is
// queried by HasOutsideCompilationAttribute and set (with an empty string
// value) by ExpandHeadOutsideCompiledOps.
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
39
HasOutsideCompilationAttribute(Operation * op)40 bool HasOutsideCompilationAttribute(Operation* op) {
41 return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
42 }
43
44 // Finds op that created a given value. If the value is a BlockArgument, this
45 // returns the owner of the Block.
GetOpOfValue(Value value)46 Operation* GetOpOfValue(Value value) {
47 if (auto block_arg = value.dyn_cast<BlockArgument>())
48 return block_arg.getOwner()->getParentOp();
49
50 return value.getDefiningOp();
51 }
52
53 // TODO(b/158596585): Replace this with a cost model analysis.
IsTrivialUnaryOperation(Operation * op)54 bool IsTrivialUnaryOperation(Operation* op) {
55 return llvm::isa<TF::CastOp, TF::IdentityOp>(op);
56 }
57
58 // Adds outside compilation attributes to unary ops such as Identity/Cast ops
59 // at the head of TPU computation that is used only by other outside compiled
60 // ops. Identity ops and Cast ops is commonly added to the start of TPU
61 // computation. Adding/expanding outside compilation attributes to these ops
62 // will ensure that head outside compiled ops are correctly located and moved to
63 // host.
64 // TODO(b/158691733): Also handle ops inside function calls/control flows.
ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,OpBuilder * builder)65 void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,
66 OpBuilder* builder) {
67 Region* cluster_region = &cluster.body();
68 llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
69
70 // Traverse the graph in topological order to find all outside compiled ops
71 // at head of TPU computation or unary ops that are only used by other outside
72 // compiled ops.
73 auto cluster_ops = cluster.GetBody().without_terminator();
74 for (Operation& cluster_op : cluster_ops) {
75 if (IsTrivialUnaryOperation(&cluster_op) ||
76 HasOutsideCompilationAttribute(&cluster_op)) {
77 auto walk_result = cluster_op.walk([&](Operation* op) {
78 for (Value operand : op->getOperands()) {
79 Operation* operand_op = GetOpOfValue(operand);
80 if (head_outside_compiled_ops.count(operand_op)) continue;
81
82 if (operand_op->getParentRegion() == cluster_region)
83 return WalkResult::interrupt();
84 }
85 return WalkResult::advance();
86 });
87
88 if (!walk_result.wasInterrupted())
89 head_outside_compiled_ops.insert(&cluster_op);
90 }
91 }
92
93 for (auto head_outside_compiled_op :
94 llvm::reverse(head_outside_compiled_ops)) {
95 auto users = head_outside_compiled_op->getUsers();
96 if (users.empty() ||
97 HasOutsideCompilationAttribute(head_outside_compiled_op))
98 continue;
99
100 bool should_expand_op_to_host_computation = true;
101 for (auto consumer_op : users) {
102 if (should_expand_op_to_host_computation &&
103 !HasOutsideCompilationAttribute(consumer_op)) {
104 should_expand_op_to_host_computation = false;
105 continue;
106 }
107 }
108
109 if (should_expand_op_to_host_computation)
110 head_outside_compiled_op->setAttr(kXlaOutsideCompilationAttr,
111 builder->getStringAttr(""));
112 }
113 }
114
115 struct TPUHostComputationExpansion
116 : public PassWrapper<TPUHostComputationExpansion, FunctionPass> {
117 void runOnFunction() override;
118 };
119
runOnFunction()120 void TPUHostComputationExpansion::runOnFunction() {
121 OpBuilder builder(&getContext());
122 getFunction().walk([&](tf_device::ClusterOp cluster) {
123 ExpandHeadOutsideCompiledOps(cluster, &builder);
124 });
125 }
126
127 } // anonymous namespace
128
CreateTPUHostComputationExpansionPass()129 std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass() {
130 return std::make_unique<TPUHostComputationExpansion>();
131 }
132
// Statically registers TPUHostComputationExpansion with the pass registry
// under the argument name `tf-tpu-host-computation-expansion`.
static PassRegistration<TPUHostComputationExpansion> pass(
    "tf-tpu-host-computation-expansion",
    "Expands host computation before and after TPU computation.");
136
137 } // namespace TFTPU
138 } // namespace mlir
139