1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <tuple>
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Block.h" // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "mlir/IR/UseDefLists.h" // from @llvm-project
28 #include "mlir/IR/Value.h" // from @llvm-project
29 #include "mlir/Pass/Pass.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
34
35 namespace mlir {
36 namespace TFTPU {
37
38 namespace {
39
// Name of the op attribute that carries a device assignment.
constexpr char kDeviceAttr[] = "device";
// Name of the function argument/result attribute that carries a device
// assignment for function boundary values.
constexpr char kFuncDeviceAttr[] = "tf.device";

43 // Checks if a function only contains a tf_executor.graph.
IsSupportedGraph(func::FuncOp func)44 bool IsSupportedGraph(func::FuncOp func) {
45 if (!llvm::hasSingleElement(func)) return false;
46
47 Block& block = func.front();
48 if (!llvm::hasSingleElement(block.without_terminator())) return false;
49
50 auto graph = llvm::dyn_cast<tf_executor::GraphOp>(block.front());
51 if (!graph) return false;
52
53 Operation* terminator = block.getTerminator();
54 if (graph.getNumResults() != terminator->getNumOperands()) return false;
55 for (auto result : llvm::zip(graph.results(), terminator->getOperands()))
56 if (std::get<0>(result) != std::get<1>(result)) return false;
57
58 return true;
59 }
60
61 // Checks if an operation of the tf_executor dialect can have TPU devices
62 // propagated through.
IsSupportedExecutorOp(Operation & op)63 bool IsSupportedExecutorOp(Operation& op) {
64 auto ops_have_same_device = [](Operation* lhs, Operation* rhs) {
65 auto lhs_device_attr = lhs->getAttrOfType<StringAttr>(kDeviceAttr);
66 auto rhs_device_attr = rhs->getAttrOfType<StringAttr>(kDeviceAttr);
67 return (!lhs_device_attr && !rhs_device_attr) ||
68 (lhs_device_attr && rhs_device_attr &&
69 lhs_device_attr.getValue() == rhs_device_attr.getValue());
70 };
71
72 // Check if tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink
73 // pair has matching devices or no devices.
74 if (auto source = llvm::dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
75 return ops_have_same_device(source, source.GetSink());
76 } else if (auto sink = llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
77 return ops_have_same_device(sink.GetSource(), sink);
78 }
79
80 return llvm::isa<tf_executor::EnterOp, tf_executor::ExitOp,
81 tf_executor::IslandOp, tf_executor::MergeOp,
82 tf_executor::SwitchOp>(op);
83 }
84
85 // Assigns all data results to a specified device.
PopulateDeviceForOpResults(Operation & op,llvm::StringRef device,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)86 void PopulateDeviceForOpResults(
87 Operation& op, llvm::StringRef device,
88 llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
89 Operation* op_to_update = &op;
90 // Use tf_executor.island op if present as non v1 control flow op results are
91 // forwarded by a parent tf_executor.island op.
92 if (llvm::isa<tf_executor::IslandOp>(op_to_update->getParentOp()))
93 op_to_update = op_to_update->getParentOp();
94
95 for (Value result : op_to_update->getResults()) {
96 if (result.getType().isa<tf_executor::TokenType>()) continue;
97 if (result.getType().isa<tf_executor::ControlType>()) break;
98
99 value_to_device.insert({result, device});
100 }
101 }
102
103 // Checks if an operation can have TPU devices propagated through.
IsSupportedOpToSetDevice(Operation & op)104 bool IsSupportedOpToSetDevice(Operation& op) {
105 return IsSupportedExecutorOp(op) ||
106 isa<TF::IdentityOp, TF::IdentityNOp, TF::ShapeOp>(op);
107 }
108
109 // Finds nonconflicting TPU device for an operation from its operands. If an
110 // operand has no device or a non TPU device, or if there are conflicting
111 // devices, and empty StringRef will be returned. Control dependencies,
112 // NextIteration.Source -> NextIteration.Sink token dependencies, and
113 // LoopCond -> Switch data dependencies are ignored.
FindDeviceFromOperands(Operation & op,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)114 llvm::StringRef FindDeviceFromOperands(
115 Operation& op,
116 const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
117 llvm::StringRef new_device;
118 const bool is_switch = llvm::isa<tf_executor::SwitchOp>(op);
119 for (Value operand : op.getOperands()) {
120 if (operand.getType().isa<tf_executor::TokenType>()) continue;
121 if (operand.getType().isa<tf_executor::ControlType>()) break;
122
123 if (is_switch &&
124 llvm::isa_and_nonnull<tf_executor::LoopCondOp>(operand.getDefiningOp()))
125 continue;
126
127 auto it = value_to_device.find(operand);
128 if (it == value_to_device.end()) return llvm::StringRef();
129
130 if (new_device.empty()) {
131 new_device = it->getSecond();
132 continue;
133 }
134
135 if (new_device != it->getSecond()) return llvm::StringRef();
136 }
137
138 return new_device;
139 }
140
141 // Propagates devices from function arguments.
PropagateDevicesFromArguments(func::FuncOp func,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)142 void PropagateDevicesFromArguments(
143 func::FuncOp func,
144 llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
145 for (BlockArgument& arg : func.getArguments()) {
146 auto arg_device_attr =
147 func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
148 if (!arg_device_attr || arg_device_attr.getValue().empty() ||
149 !tensorflow::IsTPUDevice(arg_device_attr.getValue()))
150 continue;
151 value_to_device.insert({arg, arg_device_attr.getValue()});
152 }
153 }
154
// Propagates devices from operation operands to results. Updating the device of
// a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink will result
// in multiple passes over the tf_executor.graph to propagate devices in loops.
void PropagateDevicesInGraph(
    tf_executor::GraphOp graph,
    llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
  auto ops = graph.GetBody().without_terminator();

  // Fixed-point iteration: each sweep over the graph propagates devices one
  // hop; updating a NextIteration.Sink also updates its paired Source, whose
  // results may enable further propagation, so sweep again until stable.
  bool updated_next_iteration = false;
  do {
    updated_next_iteration = false;
    for (Operation& op : ops) {
      if (!IsSupportedExecutorOp(op)) continue;

      Operation* op_to_update = &op;
      // Unpack inner op of tf_executor.island.
      if (auto island_op =
              llvm::dyn_cast<tf_executor::IslandOp>(op_to_update)) {
        // Islands wrapping multiple ops are not handled.
        if (!island_op.WrapsSingleOp()) continue;
        op_to_update = &island_op.GetBody().front();
      }

      // If op already has a TPU device set, simply propagate its device.
      auto device_attr = op_to_update->getAttrOfType<StringAttr>(kDeviceAttr);
      const bool has_device = device_attr && !device_attr.getValue().empty();
      if (has_device && tensorflow::IsTPUDevice(device_attr.getValue())) {
        PopulateDeviceForOpResults(*op_to_update, device_attr.getValue(),
                                   value_to_device);
        continue;
      }

      // Op has an unsupported device.
      if (has_device) continue;

      if (!IsSupportedOpToSetDevice(*op_to_update)) continue;

      // Derive a device from operands; empty means unknown or conflicting.
      llvm::StringRef new_device =
          FindDeviceFromOperands(*op_to_update, value_to_device);
      if (new_device.empty()) continue;

      // Assign the derived device and record it for this op's results.
      auto new_device_attr =
          mlir::StringAttr::get(op_to_update->getContext(), new_device);
      op_to_update->setAttr(kDeviceAttr, new_device_attr);
      PopulateDeviceForOpResults(*op_to_update, new_device_attr.getValue(),
                                 value_to_device);

      // Keep a NextIteration.Source/Sink pair consistent: mirror the device
      // onto the paired Source and flag that another sweep is needed.
      if (auto sink =
              llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
        auto source = sink.GetSource();
        source->setAttr(kDeviceAttr, new_device_attr);
        PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
                                   value_to_device);
        updated_next_iteration = true;
      }
    }
  } while (updated_next_iteration);
}
212
213 // Propagates devices to function results.
PropagateDevicesToResults(func::FuncOp func,tf_executor::FetchOp fetch,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)214 void PropagateDevicesToResults(
215 func::FuncOp func, tf_executor::FetchOp fetch,
216 const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
217 for (OpOperand& operand : fetch.getOperation()->getOpOperands()) {
218 if (operand.get().getType().isa<tf_executor::ControlType>()) break;
219 auto it = value_to_device.find(operand.get());
220 if (it != value_to_device.end()) {
221 auto device_attr = func.getResultAttrOfType<StringAttr>(
222 operand.getOperandNumber(), kFuncDeviceAttr);
223 if (device_attr && !device_attr.getValue().empty()) continue;
224 func.setResultAttr(operand.getOperandNumber(), kFuncDeviceAttr,
225 StringAttr::get(func.getContext(), it->getSecond()));
226 }
227 }
228 }
229
// Function pass that propagates TPU device assignments from function
// arguments and device-annotated ops through supported tf_executor/TF ops
// to op results and function results.
struct TPUDevicePropagation
    : public TF::TPUDevicePropagationPassBase<TPUDevicePropagation> {
  void runOnOperation() override;
};
234
runOnOperation()235 void TPUDevicePropagation::runOnOperation() {
236 func::FuncOp func = getOperation();
237 if (!IsSupportedGraph(func)) return;
238
239 llvm::DenseMap<Value, llvm::StringRef> value_to_device;
240 PropagateDevicesFromArguments(func, value_to_device);
241 auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
242 PropagateDevicesInGraph(graph, value_to_device);
243 PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
244 }
245
246 } // namespace
247
// Factory for the TPU device propagation pass; runs on func::FuncOp.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTPUDevicePropagationPass() {
  return std::make_unique<TPUDevicePropagation>();
}
251
252 } // namespace TFTPU
253 } // namespace mlir
254