• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/Block.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/Operation.h"  // from @llvm-project
22 #include "mlir/IR/Value.h"  // from @llvm-project
23 #include "mlir/IR/Visitors.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30 
31 namespace mlir {
32 namespace TFTPU {
33 
34 namespace {
35 
36 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
37 
HasOutsideCompilationAttribute(Operation * op)38 bool HasOutsideCompilationAttribute(Operation* op) {
39   return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
40 }
41 
42 // Finds op that created a given value. If the value is a BlockArgument, this
43 // returns the owner of the Block.
GetOpOfValue(Value value)44 Operation* GetOpOfValue(Value value) {
45   if (auto block_arg = value.dyn_cast<BlockArgument>())
46     return block_arg.getOwner()->getParentOp();
47 
48   return value.getDefiningOp();
49 }
50 
51 // TODO(b/158596585): Replace this with a cost model analysis.
IsTrivialUnaryOperation(Operation * op)52 bool IsTrivialUnaryOperation(Operation* op) {
53   return llvm::isa<TF::CastOp, TF::IdentityOp>(op);
54 }
55 
56 // Adds outside compilation attributes to unary ops such as Identity/Cast ops
57 // at the head of TPU computation that is used only by other outside compiled
58 // ops. Identity ops and Cast ops is commonly added to the start of TPU
59 // computation. Adding/expanding outside compilation attributes to these ops
60 // will ensure that head outside compiled ops are correctly located and moved to
61 // host.
62 // TODO(b/158691733): Also handle ops inside function calls/control flows.
ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,OpBuilder * builder)63 void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,
64                                   OpBuilder* builder) {
65   Region* cluster_region = &cluster.body();
66   llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
67 
68   // Traverse the graph in topological order to find all outside compiled ops
69   // at head of TPU computation or unary ops that are only used by other outside
70   // compiled ops.
71   auto cluster_ops = cluster.GetBody().without_terminator();
72   for (Operation& cluster_op : cluster_ops) {
73     if (IsTrivialUnaryOperation(&cluster_op) ||
74         HasOutsideCompilationAttribute(&cluster_op)) {
75       auto walk_result = cluster_op.walk([&](Operation* op) {
76         for (Value operand : op->getOperands()) {
77           Operation* operand_op = GetOpOfValue(operand);
78           if (head_outside_compiled_ops.count(operand_op)) continue;
79 
80           if (operand_op->getParentRegion() == cluster_region)
81             return WalkResult::interrupt();
82         }
83         return WalkResult::advance();
84       });
85 
86       if (!walk_result.wasInterrupted())
87         head_outside_compiled_ops.insert(&cluster_op);
88     }
89   }
90 
91   for (auto head_outside_compiled_op :
92        llvm::reverse(head_outside_compiled_ops)) {
93     auto users = head_outside_compiled_op->getUsers();
94     if (users.empty() ||
95         HasOutsideCompilationAttribute(head_outside_compiled_op))
96       continue;
97 
98     bool should_expand_op_to_host_computation = true;
99     for (auto consumer_op : users) {
100       if (should_expand_op_to_host_computation &&
101           !HasOutsideCompilationAttribute(consumer_op)) {
102         should_expand_op_to_host_computation = false;
103         continue;
104       }
105     }
106 
107     if (should_expand_op_to_host_computation)
108       head_outside_compiled_op->setAttr(kXlaOutsideCompilationAttr,
109                                         builder->getStringAttr(""));
110   }
111 }
112 
113 struct TPUHostComputationExpansionPass
114     : public TF::TPUHostComputationExpansionPassBase<
115           TPUHostComputationExpansionPass> {
116   void runOnFunction() override;
117 };
118 
runOnFunction()119 void TPUHostComputationExpansionPass::runOnFunction() {
120   OpBuilder builder(&getContext());
121   getFunction().walk([&](tf_device::ClusterOp cluster) {
122     ExpandHeadOutsideCompiledOps(cluster, &builder);
123   });
124 }
125 
126 }  // anonymous namespace
127 
CreateTPUHostComputationExpansionPass()128 std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass() {
129   return std::make_unique<TPUHostComputationExpansionPass>();
130 }
131 
132 }  // namespace TFTPU
133 }  // namespace mlir
134