• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <tuple>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

35 namespace mlir {
36 namespace TFTPU {
37 
38 namespace {
39 
40 constexpr char kDeviceAttr[] = "device";
41 constexpr char kFuncDeviceAttr[] = "tf.device";
42 
43 // Checks if a function only contains a tf_executor.graph.
IsSupportedGraph(func::FuncOp func)44 bool IsSupportedGraph(func::FuncOp func) {
45   if (!llvm::hasSingleElement(func)) return false;
46 
47   Block& block = func.front();
48   if (!llvm::hasSingleElement(block.without_terminator())) return false;
49 
50   auto graph = llvm::dyn_cast<tf_executor::GraphOp>(block.front());
51   if (!graph) return false;
52 
53   Operation* terminator = block.getTerminator();
54   if (graph.getNumResults() != terminator->getNumOperands()) return false;
55   for (auto result : llvm::zip(graph.results(), terminator->getOperands()))
56     if (std::get<0>(result) != std::get<1>(result)) return false;
57 
58   return true;
59 }
60 
61 // Checks if an operation of the tf_executor dialect can have TPU devices
62 // propagated through.
IsSupportedExecutorOp(Operation & op)63 bool IsSupportedExecutorOp(Operation& op) {
64   auto ops_have_same_device = [](Operation* lhs, Operation* rhs) {
65     auto lhs_device_attr = lhs->getAttrOfType<StringAttr>(kDeviceAttr);
66     auto rhs_device_attr = rhs->getAttrOfType<StringAttr>(kDeviceAttr);
67     return (!lhs_device_attr && !rhs_device_attr) ||
68            (lhs_device_attr && rhs_device_attr &&
69             lhs_device_attr.getValue() == rhs_device_attr.getValue());
70   };
71 
72   // Check if tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink
73   // pair has matching devices or no devices.
74   if (auto source = llvm::dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
75     return ops_have_same_device(source, source.GetSink());
76   } else if (auto sink = llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
77     return ops_have_same_device(sink.GetSource(), sink);
78   }
79 
80   return llvm::isa<tf_executor::EnterOp, tf_executor::ExitOp,
81                    tf_executor::IslandOp, tf_executor::MergeOp,
82                    tf_executor::SwitchOp>(op);
83 }
84 
85 // Assigns all data results to a specified device.
PopulateDeviceForOpResults(Operation & op,llvm::StringRef device,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)86 void PopulateDeviceForOpResults(
87     Operation& op, llvm::StringRef device,
88     llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
89   Operation* op_to_update = &op;
90   // Use tf_executor.island op if present as non v1 control flow op results are
91   // forwarded by a parent tf_executor.island op.
92   if (llvm::isa<tf_executor::IslandOp>(op_to_update->getParentOp()))
93     op_to_update = op_to_update->getParentOp();
94 
95   for (Value result : op_to_update->getResults()) {
96     if (result.getType().isa<tf_executor::TokenType>()) continue;
97     if (result.getType().isa<tf_executor::ControlType>()) break;
98 
99     value_to_device.insert({result, device});
100   }
101 }
102 
103 // Checks if an operation can have TPU devices propagated through.
IsSupportedOpToSetDevice(Operation & op)104 bool IsSupportedOpToSetDevice(Operation& op) {
105   return IsSupportedExecutorOp(op) ||
106          isa<TF::IdentityOp, TF::IdentityNOp, TF::ShapeOp>(op);
107 }
108 
109 // Finds nonconflicting TPU device for an operation from its operands. If an
110 // operand has no device or a non TPU device, or if there are conflicting
111 // devices, and empty StringRef will be returned. Control dependencies,
112 // NextIteration.Source -> NextIteration.Sink token dependencies, and
113 // LoopCond -> Switch data dependencies are ignored.
FindDeviceFromOperands(Operation & op,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)114 llvm::StringRef FindDeviceFromOperands(
115     Operation& op,
116     const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
117   llvm::StringRef new_device;
118   const bool is_switch = llvm::isa<tf_executor::SwitchOp>(op);
119   for (Value operand : op.getOperands()) {
120     if (operand.getType().isa<tf_executor::TokenType>()) continue;
121     if (operand.getType().isa<tf_executor::ControlType>()) break;
122 
123     if (is_switch &&
124         llvm::isa_and_nonnull<tf_executor::LoopCondOp>(operand.getDefiningOp()))
125       continue;
126 
127     auto it = value_to_device.find(operand);
128     if (it == value_to_device.end()) return llvm::StringRef();
129 
130     if (new_device.empty()) {
131       new_device = it->getSecond();
132       continue;
133     }
134 
135     if (new_device != it->getSecond()) return llvm::StringRef();
136   }
137 
138   return new_device;
139 }
140 
141 // Propagates devices from function arguments.
PropagateDevicesFromArguments(func::FuncOp func,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)142 void PropagateDevicesFromArguments(
143     func::FuncOp func,
144     llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
145   for (BlockArgument& arg : func.getArguments()) {
146     auto arg_device_attr =
147         func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
148     if (!arg_device_attr || arg_device_attr.getValue().empty() ||
149         !tensorflow::IsTPUDevice(arg_device_attr.getValue()))
150       continue;
151     value_to_device.insert({arg, arg_device_attr.getValue()});
152   }
153 }
154 
155 // Propagates devices from operation operands to results. Updating the device of
156 // a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink will result
157 // in multiple passes over the tf_executor.graph to propagate devices in loops.
PropagateDevicesInGraph(tf_executor::GraphOp graph,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)158 void PropagateDevicesInGraph(
159     tf_executor::GraphOp graph,
160     llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
161   auto ops = graph.GetBody().without_terminator();
162 
163   bool updated_next_iteration = false;
164   do {
165     updated_next_iteration = false;
166     for (Operation& op : ops) {
167       if (!IsSupportedExecutorOp(op)) continue;
168 
169       Operation* op_to_update = &op;
170       // Unpack inner op of tf_executor.island.
171       if (auto island_op =
172               llvm::dyn_cast<tf_executor::IslandOp>(op_to_update)) {
173         if (!island_op.WrapsSingleOp()) continue;
174         op_to_update = &island_op.GetBody().front();
175       }
176 
177       // If op already has a TPU device set, simply propagate its device.
178       auto device_attr = op_to_update->getAttrOfType<StringAttr>(kDeviceAttr);
179       const bool has_device = device_attr && !device_attr.getValue().empty();
180       if (has_device && tensorflow::IsTPUDevice(device_attr.getValue())) {
181         PopulateDeviceForOpResults(*op_to_update, device_attr.getValue(),
182                                    value_to_device);
183         continue;
184       }
185 
186       // Op has an unsupported device.
187       if (has_device) continue;
188 
189       if (!IsSupportedOpToSetDevice(*op_to_update)) continue;
190 
191       llvm::StringRef new_device =
192           FindDeviceFromOperands(*op_to_update, value_to_device);
193       if (new_device.empty()) continue;
194 
195       auto new_device_attr =
196           mlir::StringAttr::get(op_to_update->getContext(), new_device);
197       op_to_update->setAttr(kDeviceAttr, new_device_attr);
198       PopulateDeviceForOpResults(*op_to_update, new_device_attr.getValue(),
199                                  value_to_device);
200 
201       if (auto sink =
202               llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
203         auto source = sink.GetSource();
204         source->setAttr(kDeviceAttr, new_device_attr);
205         PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
206                                    value_to_device);
207         updated_next_iteration = true;
208       }
209     }
210   } while (updated_next_iteration);
211 }
212 
213 // Propagates devices to function results.
PropagateDevicesToResults(func::FuncOp func,tf_executor::FetchOp fetch,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)214 void PropagateDevicesToResults(
215     func::FuncOp func, tf_executor::FetchOp fetch,
216     const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
217   for (OpOperand& operand : fetch.getOperation()->getOpOperands()) {
218     if (operand.get().getType().isa<tf_executor::ControlType>()) break;
219     auto it = value_to_device.find(operand.get());
220     if (it != value_to_device.end()) {
221       auto device_attr = func.getResultAttrOfType<StringAttr>(
222           operand.getOperandNumber(), kFuncDeviceAttr);
223       if (device_attr && !device_attr.getValue().empty()) continue;
224       func.setResultAttr(operand.getOperandNumber(), kFuncDeviceAttr,
225                          StringAttr::get(func.getContext(), it->getSecond()));
226     }
227   }
228 }
229 
230 struct TPUDevicePropagation
231     : public TF::TPUDevicePropagationPassBase<TPUDevicePropagation> {
232   void runOnOperation() override;
233 };
234 
runOnOperation()235 void TPUDevicePropagation::runOnOperation() {
236   func::FuncOp func = getOperation();
237   if (!IsSupportedGraph(func)) return;
238 
239   llvm::DenseMap<Value, llvm::StringRef> value_to_device;
240   PropagateDevicesFromArguments(func, value_to_device);
241   auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
242   PropagateDevicesInGraph(graph, value_to_device);
243   PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
244 }
245 
246 }  // namespace
247 
CreateTPUDevicePropagationPass()248 std::unique_ptr<OperationPass<func::FuncOp>> CreateTPUDevicePropagationPass() {
249   return std::make_unique<TPUDevicePropagation>();
250 }
251 
252 }  // namespace TFTPU
253 }  // namespace mlir
254