1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Block.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
30 #include "mlir/IR/Operation.h" // from @llvm-project
31 #include "mlir/IR/Value.h" // from @llvm-project
32 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
38 #include "tensorflow/compiler/xla/client/sharding_builder.h"
39
40 namespace mlir {
41 namespace TFTPU {
42 namespace {
43
// Attribute name used to record an XLA sharding configuration on function
// arguments and results.
constexpr char kShardingAttr[] = "mhlo.sharding";
// Empty serialized sharding denotes replicate sharding (used when XLA SPMD is
// enabled and no explicit sharding is found).
constexpr char kReplicateSharding[] = "";
46
// Module pass that identifies the XLA sharding of inputs/outputs of
// `tf_device.cluster_func` ops and records the configurations as op and
// function attributes (see runOnOperation below).
struct TPUShardingIdentificationPass
    : public PassWrapper<TPUShardingIdentificationPass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;
};
52
53 // Returns XLA sharding from TPUPartitionedInput op connected to a
54 // `tf_device.cluster_func` operand value. If value is a resource type then
55 // TPUPartitionedInput op will be connected to a ReadVariable op that feeds into
56 // a `tf_device.cluster_func`.
GetXlaShardingFromOperand(Value value)57 llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
58 Value value_to_visit = value;
59 if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
60 value_to_visit.getDefiningOp()))
61 value_to_visit = read_var.resource();
62
63 if (auto partitioned_input =
64 llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
65 value_to_visit.getDefiningOp()))
66 return partitioned_input._XlaSharding();
67
68 return llvm::None;
69 }
70
71 // Returns XLA sharding from a XlaSharding op connected to an argument value. If
72 // value is a resource type then XlaSharding op will be connected to a
73 // ReadVariable op. XlaSharding op may be direct user of inputs but it may also
74 // be followed by an Identity op and, in the case where bfloat16 type is used,
75 // Cast op may be added right after the input.
76 //
77 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
78 // Case, While) ops and Caller return values.
79 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
80 // inputs.
GetXlaShardingFromArg(Value value)81 llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
82 llvm::SmallPtrSet<Value, 4> visited_values;
83 llvm::SmallVector<Value, 4> values_to_visit{value};
84 while (!values_to_visit.empty()) {
85 llvm::SmallVector<Value, 4> next_values_to_visit;
86 for (Value value_to_visit : values_to_visit) {
87 if (!visited_values.insert(value_to_visit).second) continue;
88
89 for (auto& use : value_to_visit.getUses()) {
90 Operation* owner = use.getOwner();
91 if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
92 return sharding._XlaSharding();
93
94 if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
95 next_values_to_visit.push_back(use.getOwner()->getResult(0));
96 continue;
97 }
98
99 if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) {
100 FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
101 if (!func) continue;
102 next_values_to_visit.push_back(
103 func.getArgument(use.getOperandNumber()));
104 }
105 }
106 }
107
108 values_to_visit.swap(next_values_to_visit);
109 }
110
111 return llvm::None;
112 }
113
114 // Extracts sharding configurations for all inputs by parsing XlaSharding/
115 // TPUPartitionedInput op connected to the operands/arguments. If argument to
116 // the `cluster_func` directly feeds into another function call op, then
117 // recursively walk the function definition to find the connected XlaSharding
118 // op.
IdentifyXlaShardingForComputationInputs(StringRef logical_core_0_sharding,bool use_spmd,tf_device::ClusterFuncOp cluster_func,FuncOp func,Builder * builder)119 void IdentifyXlaShardingForComputationInputs(
120 StringRef logical_core_0_sharding, bool use_spmd,
121 tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
122 // Look up function definition from module.
123 Block& function_block = func.front();
124
125 llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
126 sharding_for_args.reserve(function_block.getNumArguments());
127
128 // Iterate through operands of `cluster_func`.
129 // The computation operand can either be:
130 // 1) a TPUPartitionedInput Op if the input has a non-resource type;
131 // 2) a ReadVariableOp else.
132 //
133 // Replicate sharding is used if `use_spmd` is set.
134 //
135 // Iterate through input arguments to the entry block of
136 // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
137 // XlaSharding ops can:
138 // 1) Directly follow the input argument if input argument has non-resource
139 // types.
140 // 2) Follow ReadVariableOp if the input type is of resource type.
141 // 3) Follow IdentityOp or CastOp after above cases (1), (2).
142 //
143 // Sharding configurations are added to the tf_device.ClusterFunc as an
144 // attribute and the function as an argument attribute.
145 for (auto operand_and_arg :
146 llvm::zip(cluster_func.operands(), function_block.getArguments())) {
147 Value operand = std::get<0>(operand_and_arg);
148 BlockArgument arg = std::get<1>(operand_and_arg);
149 const int index = arg.getArgNumber();
150
151 if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
152 sharding_for_args.push_back(operand_sharding.getValue());
153 func.setArgAttr(index, kShardingAttr,
154 builder->getStringAttr(operand_sharding.getValue()));
155 continue;
156 }
157
158 if (use_spmd) {
159 // If XLA SPMD is enabled, host variables or non-variable per-replica
160 // inputs should take on replicate sharding, unless another sharding is
161 // set via a TPUPartitionedInput op.
162 sharding_for_args.push_back(kReplicateSharding);
163 func.setArgAttr(index, kShardingAttr,
164 builder->getStringAttr(kReplicateSharding));
165 continue;
166 }
167
168 auto arg_sharding = GetXlaShardingFromArg(arg);
169 if (arg_sharding) {
170 sharding_for_args.push_back(arg_sharding.getValue());
171 func.setArgAttr(index, kShardingAttr,
172 builder->getStringAttr(arg_sharding.getValue()));
173 continue;
174 }
175
176 // Default to maximal sharding core 0 if no sharding is present.
177 sharding_for_args.push_back(logical_core_0_sharding);
178 func.setArgAttr(index, kShardingAttr,
179 builder->getStringAttr(logical_core_0_sharding));
180 }
181
182 cluster_func->setAttr(tensorflow::kInputShardingAttr,
183 builder->getStrArrayAttr(sharding_for_args));
184 }
185
186 // Returns XLA sharding from TPUPartitionedOutput or TPUPartitionedInput (via
187 // AssignVariableOp/resource write) op connected to a `tf_device.cluster_func`
188 // result value.
GetXlaShardingFromResult(Value value)189 llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
190 if (!value.hasOneUse()) return llvm::None;
191
192 Operation* user = *value.getUsers().begin();
193 if (auto partitioned_output =
194 llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
195 return partitioned_output._XlaSharding();
196
197 if (auto assign_var = llvm::dyn_cast<TF::AssignVariableOp>(user))
198 if (auto partitioned_input =
199 llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
200 assign_var.resource().getDefiningOp()))
201 return partitioned_input._XlaSharding();
202
203 return llvm::None;
204 }
205
// Returns XLA sharding from XlaSharding op connected to a result value.
// XlaSharding op may be direct user of inputs but it may also be followed by an
// Identity op and, in the case where bfloat16 type is used, Cast op may be
// added right after the input.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
// Case, While) ops and Caller argument values.
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
// inputs.
llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) {
  // Walk backwards (def-use) from `value` toward an XlaSharding op.
  llvm::SmallPtrSet<Value, 4> visited_values;
  Value value_to_visit = value;
  while (value_to_visit) {
    // A revisited value means we are in a cycle; give up.
    if (!visited_values.insert(value_to_visit).second) return llvm::None;

    Operation* def = value_to_visit.getDefiningOp();
    if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
      return sharding._XlaSharding();

    // Pass-through ops: keep walking from their operand.
    if (llvm::isa_and_nonnull<TF::IdentityOp, TF::CastOp>(def)) {
      value_to_visit = def->getOperand(0);
      continue;
    }

    // For a call result, continue the walk at the corresponding operand of
    // the callee's terminator (i.e. the returned value).
    if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
      FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
      // NOTE: `value_to_visit` is unchanged here, so the next iteration hits
      // the visited-set check above and returns llvm::None — this `continue`
      // terminates, it does not loop forever.
      if (!func) continue;
      value_to_visit = func.front().getTerminator()->getOperand(
          value_to_visit.cast<OpResult>().getResultNumber());
      continue;
    }

    // Any other defining op (or a block argument): no sharding found.
    break;
  }

  return llvm::None;
}
243
244 // Extracts sharding configurations for all outputs by parsing XlaSharding/
245 // TPUPartitionedOutput op connected to the retvals/results.
IdentifyXlaShardingForComputationOutputs(StringRef logical_core_0_sharding,bool use_spmd,tf_device::ClusterFuncOp cluster_func,FuncOp func,Builder * builder)246 void IdentifyXlaShardingForComputationOutputs(
247 StringRef logical_core_0_sharding, bool use_spmd,
248 tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
249 Block& function_block = func.front();
250 Operation* terminator = function_block.getTerminator();
251 llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
252 sharding_for_rets.reserve(terminator->getNumOperands());
253
254 // Iterate through results of `cluster_func`. For output ops, look for
255 // TPUPartitionedOutput ops.
256 //
257 // Replicate sharding is used if `use_spmd` is set.
258 //
259 // Iterate through operands of the terminator. If the preceding op is
260 // XlaShardingOp, then the provided sharding configuration is added to the
261 // tf_device.ClusterFunc as an attribute and the function as a result
262 // attribute.
263 for (auto result_and_retval :
264 llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
265 Value result = std::get<0>(result_and_retval);
266 OpOperand& retval = std::get<1>(result_and_retval);
267 const int index = retval.getOperandNumber();
268
269 if (auto result_sharding = GetXlaShardingFromResult(result)) {
270 sharding_for_rets.push_back(result_sharding.getValue());
271 func.setResultAttr(index, kShardingAttr,
272 builder->getStringAttr(result_sharding.getValue()));
273 continue;
274 }
275
276 if (use_spmd) {
277 // If XLA SPMD is enabled, outputs all should have replicate sharding,
278 // unless another sharding is set via a TPUPartitionedOutput op.
279 sharding_for_rets.push_back(kReplicateSharding);
280 func.setResultAttr(index, kShardingAttr,
281 builder->getStringAttr(kReplicateSharding));
282 continue;
283 }
284
285 if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
286 sharding_for_rets.push_back(retval_sharding.getValue());
287 func.setResultAttr(index, kShardingAttr,
288 builder->getStringAttr(retval_sharding.getValue()));
289 continue;
290 }
291
292 // Default to maximal sharding core 0 if no sharding is present.
293 sharding_for_rets.push_back(logical_core_0_sharding);
294 func.setResultAttr(index, kShardingAttr,
295 builder->getStringAttr(logical_core_0_sharding));
296 }
297
298 cluster_func->setAttr(tensorflow::kOutputShardingAttr,
299 builder->getStrArrayAttr(sharding_for_rets));
300 }
301
302 // Extracts input/output sharding configuration of `cluster_func` by parsing
303 // XlaSharding ops inside the `cluster_func`.
IdentifyXlaShardingForTPUComputation(Builder * builder,tf_device::ClusterFuncOp cluster_func)304 void IdentifyXlaShardingForTPUComputation(
305 Builder* builder, tf_device::ClusterFuncOp cluster_func) {
306 // Look up function definition from module.
307 FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
308 cluster_func.func());
309
310 // By default inputs/outputs have maximal sharding and are assigned to logical
311 // core 0 if no sharding is defined.
312 const std::string logical_core_0_sharding =
313 xla::sharding_builder::AssignDevice(0).SerializeAsString();
314
315 bool use_spmd = false;
316 if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
317 "use_spmd_for_xla_partitioning"))
318 use_spmd = use_spmd_attr.getValue();
319
320 IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
321 cluster_func, func, builder);
322
323 IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, use_spmd,
324 cluster_func, func, builder);
325 }
326
runOnOperation()327 void TPUShardingIdentificationPass::runOnOperation() {
328 Builder builder(getOperation().getContext());
329
330 getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
331 IdentifyXlaShardingForTPUComputation(&builder, cluster_func);
332 });
333 }
334
335 } // anonymous namespace
336
// Creates a new instance of the TPU sharding identification pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
  return std::make_unique<TPUShardingIdentificationPass>();
}
340
// Registers the pass under `tf-tpu-sharding-identification` so it is
// available from the pass pipeline / command line.
static PassRegistration<TPUShardingIdentificationPass> pass(
    "tf-tpu-sharding-identification",
    "Identifies and handles inputs/outputs of TPU computation that is "
    "sharded across logical cores.");
345
346 } // namespace TFTPU
347 } // namespace mlir
348