1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <tuple>
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Block.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/UseDefLists.h" // from @llvm-project
27 #include "mlir/IR/Value.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
32
33 namespace mlir {
34 namespace TFTPU {
35
36 namespace {
37
38 constexpr char kDeviceAttr[] = "device";
39 constexpr char kFuncDeviceAttr[] = "tf.device";
40
41 // Checks if a function only contains a tf_executor.graph.
IsSupportedGraph(FuncOp func)42 bool IsSupportedGraph(FuncOp func) {
43 if (!llvm::hasSingleElement(func)) return false;
44
45 Block& block = func.front();
46 if (!llvm::hasSingleElement(block.without_terminator())) return false;
47
48 auto graph = llvm::dyn_cast<tf_executor::GraphOp>(block.front());
49 if (!graph) return false;
50
51 Operation* terminator = block.getTerminator();
52 if (graph.getNumResults() != terminator->getNumOperands()) return false;
53 for (auto result : llvm::zip(graph.results(), terminator->getOperands()))
54 if (std::get<0>(result) != std::get<1>(result)) return false;
55
56 return true;
57 }
58
59 // Checks if an operation of the tf_executor dialect can have TPU devices
60 // propagated through.
IsSupportedExecutorOp(Operation & op)61 bool IsSupportedExecutorOp(Operation& op) {
62 auto ops_have_same_device = [](Operation* lhs, Operation* rhs) {
63 auto lhs_device_attr = lhs->getAttrOfType<StringAttr>(kDeviceAttr);
64 auto rhs_device_attr = rhs->getAttrOfType<StringAttr>(kDeviceAttr);
65 return (!lhs_device_attr && !rhs_device_attr) ||
66 (lhs_device_attr && rhs_device_attr &&
67 lhs_device_attr.getValue() == rhs_device_attr.getValue());
68 };
69
70 // Check if tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink
71 // pair has matching devices or no devices.
72 if (auto source = llvm::dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
73 return ops_have_same_device(source, source.GetSink());
74 } else if (auto sink = llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
75 return ops_have_same_device(sink.GetSource(), sink);
76 }
77
78 return llvm::isa<tf_executor::EnterOp, tf_executor::ExitOp,
79 tf_executor::IslandOp, tf_executor::MergeOp,
80 tf_executor::SwitchOp>(op);
81 }
82
83 // Assigns all data results to a specified device.
PopulateDeviceForOpResults(Operation & op,llvm::StringRef device,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)84 void PopulateDeviceForOpResults(
85 Operation& op, llvm::StringRef device,
86 llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
87 Operation* op_to_update = &op;
88 // Use tf_executor.island op if present as non v1 control flow op results are
89 // forwarded by a parent tf_executor.island op.
90 if (llvm::isa<tf_executor::IslandOp>(op_to_update->getParentOp()))
91 op_to_update = op_to_update->getParentOp();
92
93 for (Value result : op_to_update->getResults()) {
94 if (result.getType().isa<tf_executor::TokenType>()) continue;
95 if (result.getType().isa<tf_executor::ControlType>()) break;
96
97 value_to_device.insert({result, device});
98 }
99 }
100
101 // Checks if an operation can have TPU devices propagated through.
IsSupportedOpToSetDevice(Operation & op)102 bool IsSupportedOpToSetDevice(Operation& op) {
103 return IsSupportedExecutorOp(op) ||
104 isa<TF::IdentityOp, TF::IdentityNOp, TF::ShapeOp>(op);
105 }
106
// Finds a nonconflicting TPU device for an operation from its operands. If an
// operand has no device or a non TPU device, or if there are conflicting
// devices, an empty StringRef will be returned. Control dependencies,
// NextIteration.Source -> NextIteration.Sink token dependencies, and
// LoopCond -> Switch data dependencies are ignored.
FindDeviceFromOperands(Operation & op,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)112 llvm::StringRef FindDeviceFromOperands(
113 Operation& op,
114 const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
115 llvm::StringRef new_device;
116 const bool is_switch = llvm::isa<tf_executor::SwitchOp>(op);
117 for (Value operand : op.getOperands()) {
118 if (operand.getType().isa<tf_executor::TokenType>()) continue;
119 if (operand.getType().isa<tf_executor::ControlType>()) break;
120
121 if (is_switch &&
122 llvm::isa_and_nonnull<tf_executor::LoopCondOp>(operand.getDefiningOp()))
123 continue;
124
125 auto it = value_to_device.find(operand);
126 if (it == value_to_device.end()) return llvm::StringRef();
127
128 if (new_device.empty()) {
129 new_device = it->getSecond();
130 continue;
131 }
132
133 if (new_device != it->getSecond()) return llvm::StringRef();
134 }
135
136 return new_device;
137 }
138
139 // Propagates devices from function arguments.
PropagateDevicesFromArguments(FuncOp func,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)140 void PropagateDevicesFromArguments(
141 FuncOp func, llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
142 for (BlockArgument& arg : func.getArguments()) {
143 auto arg_device_attr =
144 func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
145 if (!arg_device_attr || arg_device_attr.getValue().empty() ||
146 !tensorflow::IsTPUDevice(arg_device_attr.getValue()))
147 continue;
148 value_to_device.insert({arg, arg_device_attr.getValue()});
149 }
150 }
151
152 // Propagates devices from operation operands to results. Updating the device of
153 // a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink will result
154 // in multiple passes over the tf_executor.graph to propagate devices in loops.
void PropagateDevicesInGraph(
    tf_executor::GraphOp graph,
    llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
  auto ops = graph.GetBody().without_terminator();

  // Iterate to a fixed point: updating a NextIteration.Sink also updates its
  // paired NextIteration.Source, which may enable further propagation through
  // ops already visited earlier in the current sweep over the graph.
  bool updated_next_iteration = false;
  do {
    updated_next_iteration = false;
    for (Operation& op : ops) {
      if (!IsSupportedExecutorOp(op)) continue;

      Operation* op_to_update = &op;
      // Unpack inner op of tf_executor.island.
      if (auto island_op =
              llvm::dyn_cast<tf_executor::IslandOp>(op_to_update)) {
        if (!island_op.WrapsSingleOp()) continue;
        op_to_update = &island_op.GetBody().front();
      }

      // If op already has a TPU device set, simply propagate its device.
      auto device_attr = op_to_update->getAttrOfType<StringAttr>(kDeviceAttr);
      const bool has_device = device_attr && !device_attr.getValue().empty();
      if (has_device && tensorflow::IsTPUDevice(device_attr.getValue())) {
        PopulateDeviceForOpResults(*op_to_update, device_attr.getValue(),
                                   value_to_device);
        continue;
      }

      // Op has an unsupported device.
      if (has_device) continue;

      if (!IsSupportedOpToSetDevice(*op_to_update)) continue;

      // Derive a device from the op's (data) operands; skip on conflicts or
      // unknown operand devices.
      llvm::StringRef new_device =
          FindDeviceFromOperands(*op_to_update, value_to_device);
      if (new_device.empty()) continue;

      auto new_device_attr =
          mlir::StringAttr::get(op_to_update->getContext(), new_device);
      op_to_update->setAttr(kDeviceAttr, new_device_attr);
      PopulateDeviceForOpResults(*op_to_update, new_device_attr.getValue(),
                                 value_to_device);

      // Keep the NextIteration.Source in sync with its sink, and request
      // another sweep so uses of the source can pick up the new device.
      if (auto sink =
              llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
        auto source = sink.GetSource();
        source->setAttr(kDeviceAttr, new_device_attr);
        PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
                                   value_to_device);
        updated_next_iteration = true;
      }
    }
  } while (updated_next_iteration);
}
209
210 // Propagates devices to function results.
PropagateDevicesToResults(FuncOp func,tf_executor::FetchOp fetch,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)211 void PropagateDevicesToResults(
212 FuncOp func, tf_executor::FetchOp fetch,
213 const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
214 for (OpOperand& operand : fetch.getOperation()->getOpOperands()) {
215 if (operand.get().getType().isa<tf_executor::ControlType>()) break;
216 auto it = value_to_device.find(operand.get());
217 if (it != value_to_device.end()) {
218 auto device_attr = func.getResultAttrOfType<StringAttr>(
219 operand.getOperandNumber(), kFuncDeviceAttr);
220 if (device_attr && !device_attr.getValue().empty()) continue;
221 func.setResultAttr(operand.getOperandNumber(), kFuncDeviceAttr,
222 StringAttr::get(func.getContext(), it->getSecond()));
223 }
224 }
225 }
226
227 struct TPUDevicePropagation
228 : public PassWrapper<TPUDevicePropagation, FunctionPass> {
229 void runOnFunction() override;
getArgumentmlir::TFTPU::__anonb7bea18d0111::TPUDevicePropagation230 StringRef getArgument() const final {
231 // This is the argument used to refer to the pass in
232 // the textual format (on the commandline for example).
233 return "tf-tpu-device-propagation";
234 }
getDescriptionmlir::TFTPU::__anonb7bea18d0111::TPUDevicePropagation235 StringRef getDescription() const final {
236 // This is a brief description of the pass.
237 return "Propagates TPU devices from ops to users";
238 }
239 };
240
runOnFunction()241 void TPUDevicePropagation::runOnFunction() {
242 FuncOp func = getFunction();
243 if (!IsSupportedGraph(func)) return;
244
245 llvm::DenseMap<Value, llvm::StringRef> value_to_device;
246 PropagateDevicesFromArguments(func, value_to_device);
247 auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
248 PropagateDevicesInGraph(graph, value_to_device);
249 PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
250 }
251
252 } // namespace
253
CreateTPUDevicePropagationPass()254 std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass() {
255 return std::make_unique<TPUDevicePropagation>();
256 }
257
// Registers the pass with the global pass registry so it is available by name
// ("tf-tpu-device-propagation") in textual pass pipelines.
static PassRegistration<TPUDevicePropagation> pass;
259
260 } // namespace TFTPU
261 } // namespace mlir
262