1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <tuple>
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Block.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
32 
33 namespace mlir {
34 namespace TFTPU {
35 
36 namespace {
37 
38 constexpr char kDeviceAttr[] = "device";
39 constexpr char kFuncDeviceAttr[] = "tf.device";
40 
41 // Checks if a function only contains a tf_executor.graph.
IsSupportedGraph(FuncOp func)42 bool IsSupportedGraph(FuncOp func) {
43   if (!llvm::hasSingleElement(func)) return false;
44 
45   Block& block = func.front();
46   if (!llvm::hasSingleElement(block.without_terminator())) return false;
47 
48   auto graph = llvm::dyn_cast<tf_executor::GraphOp>(block.front());
49   if (!graph) return false;
50 
51   Operation* terminator = block.getTerminator();
52   if (graph.getNumResults() != terminator->getNumOperands()) return false;
53   for (auto result : llvm::zip(graph.results(), terminator->getOperands()))
54     if (std::get<0>(result) != std::get<1>(result)) return false;
55 
56   return true;
57 }
58 
59 // Checks if an operation of the tf_executor dialect can have TPU devices
60 // propagated through.
IsSupportedExecutorOp(Operation & op)61 bool IsSupportedExecutorOp(Operation& op) {
62   auto ops_have_same_device = [](Operation* lhs, Operation* rhs) {
63     auto lhs_device_attr = lhs->getAttrOfType<StringAttr>(kDeviceAttr);
64     auto rhs_device_attr = rhs->getAttrOfType<StringAttr>(kDeviceAttr);
65     return (!lhs_device_attr && !rhs_device_attr) ||
66            (lhs_device_attr && rhs_device_attr &&
67             lhs_device_attr.getValue() == rhs_device_attr.getValue());
68   };
69 
70   // Check if tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink
71   // pair has matching devices or no devices.
72   if (auto source = llvm::dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
73     return ops_have_same_device(source, source.GetSink());
74   } else if (auto sink = llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
75     return ops_have_same_device(sink.GetSource(), sink);
76   }
77 
78   return llvm::isa<tf_executor::EnterOp, tf_executor::ExitOp,
79                    tf_executor::IslandOp, tf_executor::MergeOp,
80                    tf_executor::SwitchOp>(op);
81 }
82 
83 // Assigns all data results to a specified device.
PopulateDeviceForOpResults(Operation & op,llvm::StringRef device,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)84 void PopulateDeviceForOpResults(
85     Operation& op, llvm::StringRef device,
86     llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
87   Operation* op_to_update = &op;
88   // Use tf_executor.island op if present as non v1 control flow op results are
89   // forwarded by a parent tf_executor.island op.
90   if (llvm::isa<tf_executor::IslandOp>(op_to_update->getParentOp()))
91     op_to_update = op_to_update->getParentOp();
92 
93   for (Value result : op_to_update->getResults()) {
94     if (result.getType().isa<tf_executor::TokenType>()) continue;
95     if (result.getType().isa<tf_executor::ControlType>()) break;
96 
97     value_to_device.insert({result, device});
98   }
99 }
100 
101 // Checks if an operation can have TPU devices propagated through.
IsSupportedOpToSetDevice(Operation & op)102 bool IsSupportedOpToSetDevice(Operation& op) {
103   return IsSupportedExecutorOp(op) ||
104          isa<TF::IdentityOp, TF::IdentityNOp, TF::ShapeOp>(op);
105 }
106 
// Finds nonconflicting TPU device for an operation from its operands. If an
// operand has no device or a non TPU device, or if there are conflicting
// devices, an empty StringRef will be returned. Control dependencies,
// NextIteration.Source -> NextIteration.Sink token dependencies, and
// LoopCond -> Switch data dependencies are ignored.
FindDeviceFromOperands(Operation & op,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)112 llvm::StringRef FindDeviceFromOperands(
113     Operation& op,
114     const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
115   llvm::StringRef new_device;
116   const bool is_switch = llvm::isa<tf_executor::SwitchOp>(op);
117   for (Value operand : op.getOperands()) {
118     if (operand.getType().isa<tf_executor::TokenType>()) continue;
119     if (operand.getType().isa<tf_executor::ControlType>()) break;
120 
121     if (is_switch &&
122         llvm::isa_and_nonnull<tf_executor::LoopCondOp>(operand.getDefiningOp()))
123       continue;
124 
125     auto it = value_to_device.find(operand);
126     if (it == value_to_device.end()) return llvm::StringRef();
127 
128     if (new_device.empty()) {
129       new_device = it->getSecond();
130       continue;
131     }
132 
133     if (new_device != it->getSecond()) return llvm::StringRef();
134   }
135 
136   return new_device;
137 }
138 
139 // Propagates devices from function arguments.
PropagateDevicesFromArguments(FuncOp func,llvm::DenseMap<Value,llvm::StringRef> & value_to_device)140 void PropagateDevicesFromArguments(
141     FuncOp func, llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
142   for (BlockArgument& arg : func.getArguments()) {
143     auto arg_device_attr =
144         func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
145     if (!arg_device_attr || arg_device_attr.getValue().empty() ||
146         !tensorflow::IsTPUDevice(arg_device_attr.getValue()))
147       continue;
148     value_to_device.insert({arg, arg_device_attr.getValue()});
149   }
150 }
151 
152 // Propagates devices from operation operands to results. Updating the device of
153 // a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink will result
154 // in multiple passes over the tf_executor.graph to propagate devices in loops.
void PropagateDevicesInGraph(
    tf_executor::GraphOp graph,
    llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
  auto ops = graph.GetBody().without_terminator();

  // Iterate to a fixed point: assigning a device via a NextIteration.Sink
  // also updates its paired Source, which may appear earlier in the body and
  // enable further propagation on the next sweep.
  bool updated_next_iteration = false;
  do {
    updated_next_iteration = false;
    for (Operation& op : ops) {
      if (!IsSupportedExecutorOp(op)) continue;

      Operation* op_to_update = &op;
      // Unpack inner op of tf_executor.island.
      if (auto island_op =
              llvm::dyn_cast<tf_executor::IslandOp>(op_to_update)) {
        if (!island_op.WrapsSingleOp()) continue;
        op_to_update = &island_op.GetBody().front();
      }

      // If op already has a TPU device set, simply propagate its device.
      auto device_attr = op_to_update->getAttrOfType<StringAttr>(kDeviceAttr);
      const bool has_device = device_attr && !device_attr.getValue().empty();
      if (has_device && tensorflow::IsTPUDevice(device_attr.getValue())) {
        PopulateDeviceForOpResults(*op_to_update, device_attr.getValue(),
                                   value_to_device);
        continue;
      }

      // Op has an unsupported (non-TPU) device; do not propagate through it.
      if (has_device) continue;

      if (!IsSupportedOpToSetDevice(*op_to_update)) continue;

      // Derive a device from the op's operands; skip if none or conflicting.
      llvm::StringRef new_device =
          FindDeviceFromOperands(*op_to_update, value_to_device);
      if (new_device.empty()) continue;

      auto new_device_attr =
          mlir::StringAttr::get(op_to_update->getContext(), new_device);
      op_to_update->setAttr(kDeviceAttr, new_device_attr);
      PopulateDeviceForOpResults(*op_to_update, new_device_attr.getValue(),
                                 value_to_device);

      // Mirror the device onto the paired NextIteration.Source so values fed
      // back through the loop pick it up, and request another sweep.
      if (auto sink =
              llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
        auto source = sink.GetSource();
        source->setAttr(kDeviceAttr, new_device_attr);
        PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
                                   value_to_device);
        updated_next_iteration = true;
      }
    }
  } while (updated_next_iteration);
}
209 
210 // Propagates devices to function results.
PropagateDevicesToResults(FuncOp func,tf_executor::FetchOp fetch,const llvm::DenseMap<Value,llvm::StringRef> & value_to_device)211 void PropagateDevicesToResults(
212     FuncOp func, tf_executor::FetchOp fetch,
213     const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
214   for (OpOperand& operand : fetch.getOperation()->getOpOperands()) {
215     if (operand.get().getType().isa<tf_executor::ControlType>()) break;
216     auto it = value_to_device.find(operand.get());
217     if (it != value_to_device.end()) {
218       auto device_attr = func.getResultAttrOfType<StringAttr>(
219           operand.getOperandNumber(), kFuncDeviceAttr);
220       if (device_attr && !device_attr.getValue().empty()) continue;
221       func.setResultAttr(operand.getOperandNumber(), kFuncDeviceAttr,
222                          StringAttr::get(func.getContext(), it->getSecond()));
223     }
224   }
225 }
226 
227 struct TPUDevicePropagation
228     : public PassWrapper<TPUDevicePropagation, FunctionPass> {
229   void runOnFunction() override;
getArgumentmlir::TFTPU::__anonb7bea18d0111::TPUDevicePropagation230   StringRef getArgument() const final {
231     // This is the argument used to refer to the pass in
232     // the textual format (on the commandline for example).
233     return "tf-tpu-device-propagation";
234   }
getDescriptionmlir::TFTPU::__anonb7bea18d0111::TPUDevicePropagation235   StringRef getDescription() const final {
236     // This is a brief description of the pass.
237     return "Propagates TPU devices from ops to users";
238   }
239 };
240 
runOnFunction()241 void TPUDevicePropagation::runOnFunction() {
242   FuncOp func = getFunction();
243   if (!IsSupportedGraph(func)) return;
244 
245   llvm::DenseMap<Value, llvm::StringRef> value_to_device;
246   PropagateDevicesFromArguments(func, value_to_device);
247   auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
248   PropagateDevicesInGraph(graph, value_to_device);
249   PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
250 }
251 
252 }  // namespace
253 
CreateTPUDevicePropagationPass()254 std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass() {
255   return std::make_unique<TPUDevicePropagation>();
256 }
257 
// Registers the pass with the global MLIR pass registry so it is available
// by its textual argument.
static PassRegistration<TPUDevicePropagation> pass;
259 
260 }  // namespace TFTPU
261 }  // namespace mlir
262