1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Block.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
30 #include "mlir/IR/Operation.h" // from @llvm-project
31 #include "mlir/IR/Value.h" // from @llvm-project
32 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
38 #include "tensorflow/compiler/xla/client/sharding_builder.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40
41 namespace mlir {
42 namespace TFTPU {
43 namespace {
44
45 constexpr char kReplicateSharding[] = "";
46 constexpr char kShardingAttr[] = "mhlo.sharding";
47 constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning";
48
// Module pass that walks each `tf_device.cluster_func` in the module and
// records XLA input/output sharding configuration on the op and its computation
// function (see IdentifyXlaShardingForTPUComputation).
struct TPUShardingIdentificationPass
    : public PassWrapper<TPUShardingIdentificationPass,
                         OperationPass<ModuleOp>> {
  // Command-line flag under which the pass is registered.
  StringRef getArgument() const final {
    return "tf-tpu-sharding-identification";
  }

  // One-line summary shown in pass documentation / -help output.
  StringRef getDescription() const final {
    return "Identifies and handles inputs/outputs of TPU computation that is "
           "sharded across logical cores.";
  }

  void runOnOperation() override;
};
63
64 // Returns XLA sharding from TPUPartitionedInput op connected to a
65 // `tf_device.cluster_func` operand value. If value is a resource type then
66 // TPUPartitionedInput op will be connected to a ReadVariable op that feeds into
67 // a `tf_device.cluster_func`.
GetXlaShardingFromOperand(Value value)68 llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
69 Value value_to_visit = value;
70 if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
71 value_to_visit.getDefiningOp()))
72 value_to_visit = read_var.resource();
73
74 if (auto partitioned_input =
75 llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
76 value_to_visit.getDefiningOp()))
77 return partitioned_input._XlaSharding();
78
79 return llvm::None;
80 }
81
82 // Returns XLA sharding from a XlaSharding op connected to an argument value. If
83 // value is a resource type then XlaSharding op will be connected to a
84 // ReadVariable op. XlaSharding op may be direct user of inputs but it may also
85 // be followed by an Identity op and, in the case where bfloat16 type is used,
86 // Cast op may be added right after the input.
87 //
88 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
89 // Case, While) ops and Caller return values.
90 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
91 // inputs.
GetXlaShardingFromArg(Value value)92 llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
93 llvm::SmallPtrSet<Value, 4> visited_values;
94 llvm::SmallVector<Value, 4> values_to_visit{value};
95 while (!values_to_visit.empty()) {
96 llvm::SmallVector<Value, 4> next_values_to_visit;
97 for (Value value_to_visit : values_to_visit) {
98 if (!visited_values.insert(value_to_visit).second) continue;
99
100 for (auto& use : value_to_visit.getUses()) {
101 Operation* owner = use.getOwner();
102 if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
103 return sharding._XlaSharding();
104
105 if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
106 next_values_to_visit.push_back(use.getOwner()->getResult(0));
107 continue;
108 }
109
110 if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) {
111 FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
112 if (!func) continue;
113 next_values_to_visit.push_back(
114 func.getArgument(use.getOperandNumber()));
115 }
116 }
117 }
118
119 values_to_visit.swap(next_values_to_visit);
120 }
121
122 return llvm::None;
123 }
124
125 // Extracts sharding configurations for all inputs by parsing XlaSharding/
126 // TPUPartitionedInput op connected to the operands/arguments. If argument to
127 // the `cluster_func` directly feeds into another function call op, then
128 // recursively walk the function definition to find the connected XlaSharding
129 // op.
IdentifyXlaShardingForComputationInputs(StringRef logical_core_0_sharding,bool use_spmd,tf_device::ClusterFuncOp cluster_func,FuncOp func,Builder * builder,llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_args)130 void IdentifyXlaShardingForComputationInputs(
131 StringRef logical_core_0_sharding, bool use_spmd,
132 tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder,
133 llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) {
134 // Look up function definition from module.
135 Block& function_block = func.front();
136
137 sharding_for_args.reserve(function_block.getNumArguments());
138
139 // Iterate through operands of `cluster_func`.
140 // The computation operand can either be:
141 // 1) a TPUPartitionedInput Op if the input has a non-resource type;
142 // 2) a ReadVariableOp else.
143 //
144 // Replicate sharding is used if `use_spmd` is set.
145 //
146 // Iterate through input arguments to the entry block of
147 // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
148 // XlaSharding ops can:
149 // 1) Directly follow the input argument if input argument has non-resource
150 // types.
151 // 2) Follow ReadVariableOp if the input type is of resource type.
152 // 3) Follow IdentityOp or CastOp after above cases (1), (2).
153 //
154 // Sharding configurations are added to the tf_device.ClusterFunc as an
155 // attribute and the function as an argument attribute.
156 for (auto operand_and_arg :
157 llvm::zip(cluster_func.operands(), function_block.getArguments())) {
158 Value operand = std::get<0>(operand_and_arg);
159 BlockArgument arg = std::get<1>(operand_and_arg);
160
161 if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
162 sharding_for_args.push_back(operand_sharding.getValue());
163 continue;
164 }
165
166 if (use_spmd) {
167 // If XLA SPMD is enabled, host variables or non-variable per-replica
168 // inputs should take on replicate sharding, unless another sharding is
169 // set via a TPUPartitionedInput op.
170 sharding_for_args.push_back(kReplicateSharding);
171 continue;
172 }
173
174 auto arg_sharding = GetXlaShardingFromArg(arg);
175 if (arg_sharding) {
176 sharding_for_args.push_back(arg_sharding.getValue());
177 continue;
178 }
179
180 // Default to maximal sharding core 0 if no sharding is present.
181 sharding_for_args.push_back(logical_core_0_sharding);
182 }
183 }
184
185 // Returns XLA sharding from TPUPartitionedOutput or TPUPartitionedInput (via
186 // AssignVariableOp/resource write) op connected to a `tf_device.cluster_func`
187 // result value.
GetXlaShardingFromResult(Value value)188 llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
189 if (!value.hasOneUse()) return llvm::None;
190
191 Operation* user = *value.getUsers().begin();
192 if (auto partitioned_output =
193 llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
194 return partitioned_output._XlaSharding();
195
196 if (auto assign_var = llvm::dyn_cast<TF::AssignVariableOp>(user))
197 if (auto partitioned_input =
198 llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
199 assign_var.resource().getDefiningOp()))
200 return partitioned_input._XlaSharding();
201
202 return llvm::None;
203 }
204
205 // Returns XLA sharding from XlaSharding op connected to a result value.
206 // XlaSharding op may be direct user of inputs but it may also be followed by an
207 // Identity op and, in the case where bfloat16 type is used, Cast op may be
208 // added right after the input.
209 //
210 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
211 // Case, While) ops and Caller argument values.
212 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
213 // inputs.
GetXlaShardingFromRetval(Value value)214 llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) {
215 llvm::SmallPtrSet<Value, 4> visited_values;
216 Value value_to_visit = value;
217 while (value_to_visit) {
218 if (!visited_values.insert(value_to_visit).second) return llvm::None;
219
220 Operation* def = value_to_visit.getDefiningOp();
221 if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
222 return sharding._XlaSharding();
223
224 if (llvm::isa_and_nonnull<TF::IdentityOp, TF::CastOp>(def)) {
225 value_to_visit = def->getOperand(0);
226 continue;
227 }
228
229 if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
230 FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
231 if (!func) continue;
232 value_to_visit = func.front().getTerminator()->getOperand(
233 value_to_visit.cast<OpResult>().getResultNumber());
234 continue;
235 }
236
237 break;
238 }
239
240 return llvm::None;
241 }
242
243 // Extracts sharding configurations for all outputs by parsing XlaSharding/
244 // TPUPartitionedOutput op connected to the retvals/results.
IdentifyXlaShardingForComputationOutputs(StringRef logical_core_0_sharding,bool use_spmd,tf_device::ClusterFuncOp cluster_func,FuncOp func,Builder * builder,llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_rets)245 void IdentifyXlaShardingForComputationOutputs(
246 StringRef logical_core_0_sharding, bool use_spmd,
247 tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder,
248 llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) {
249 Block& function_block = func.front();
250 Operation* terminator = function_block.getTerminator();
251 sharding_for_rets.reserve(terminator->getNumOperands());
252
253 // Iterate through results of `cluster_func`. For output ops, look for
254 // TPUPartitionedOutput ops.
255 //
256 // Replicate sharding is used if `use_spmd` is set.
257 //
258 // Iterate through operands of the terminator. If the preceding op is
259 // XlaShardingOp, then the provided sharding configuration is added to the
260 // tf_device.ClusterFunc as an attribute and the function as a result
261 // attribute.
262 for (auto result_and_retval :
263 llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
264 Value result = std::get<0>(result_and_retval);
265 OpOperand& retval = std::get<1>(result_and_retval);
266
267 if (auto result_sharding = GetXlaShardingFromResult(result)) {
268 sharding_for_rets.push_back(result_sharding.getValue());
269 continue;
270 }
271
272 if (use_spmd) {
273 // If XLA SPMD is enabled, outputs all should have replicate sharding,
274 // unless another sharding is set via a TPUPartitionedOutput op.
275 sharding_for_rets.push_back(kReplicateSharding);
276 continue;
277 }
278
279 if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
280 sharding_for_rets.push_back(retval_sharding.getValue());
281 continue;
282 }
283
284 // Default to maximal sharding core 0 if no sharding is present.
285 sharding_for_rets.push_back(logical_core_0_sharding);
286 }
287 }
288
// Extracts input/output sharding configuration of `cluster_func` by parsing
// XlaSharding ops inside the `cluster_func`, then writes the configuration
// back as attributes on the cluster_func op, its computation function's
// arguments, and that function's results.
void IdentifyXlaShardingForTPUComputation(
    Builder* builder, tf_device::ClusterFuncOp cluster_func) {
  // Look up function definition from module.
  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      cluster_func.func());

  // By default inputs/outputs have maximal sharding and are assigned to logical
  // core 0 if no sharding is defined.
  const std::string logical_core_0_sharding =
      xla::sharding_builder::AssignDevice(0).SerializeAsString();

  // SPMD mode is opted into via the `use_spmd_for_xla_partitioning` attribute;
  // absent attribute means MPMD.
  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(kUseSpmdAttr))
    use_spmd = use_spmd_attr.getValue();

  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
                                          cluster_func, func, builder,
                                          sharding_for_args);

  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
  IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, use_spmd,
                                           cluster_func, func, builder,
                                           sharding_for_rets);

  // True iff the serialized proto parses to an OpSharding of MAXIMAL type
  // (value pinned to a single logical core).
  auto has_maximal_sharding = [](llvm::StringRef sharding_string) -> bool {
    xla::OpSharding sharding;
    sharding.ParseFromString(sharding_string.str());
    return sharding.type() == xla::OpSharding::MAXIMAL;
  };

  // XLA SPMD only supports cases where all inputs/outputs exist on every
  // partition (sharded or replicated). If any of the inputs/outputs have
  // maximal sharding, then fallback to MPMD.
  if (use_spmd && (absl::c_any_of(sharding_for_args, has_maximal_sharding) ||
                   absl::c_any_of(sharding_for_rets, has_maximal_sharding))) {
    LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
                    "exist on every partition (sharded or replicated). If any "
                    "of the inputs/outputs have maximal sharding, then "
                    "fallback to MPMD.";
    // Recompute every sharding from scratch with SPMD disabled and record the
    // downgrade on the op so later passes see the effective mode.
    sharding_for_args.clear();
    sharding_for_rets.clear();
    cluster_func->setAttr(kUseSpmdAttr, builder->getBoolAttr(false));
    IdentifyXlaShardingForComputationInputs(logical_core_0_sharding,
                                            /*use_spmd=*/false, cluster_func,
                                            func, builder, sharding_for_args);
    IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding,
                                             /*use_spmd=*/false, cluster_func,
                                             func, builder, sharding_for_rets);
  }

  // Update sharding on function arguments and returns.
  Block& function_block = func.front();
  for (auto sharding_and_arg :
       llvm::zip(sharding_for_args, function_block.getArguments())) {
    StringRef sharding = std::get<0>(sharding_and_arg);
    BlockArgument arg = std::get<1>(sharding_and_arg);
    func.setArgAttr(arg.getArgNumber(), kShardingAttr,
                    builder->getStringAttr(sharding));
  }

  Operation* terminator = function_block.getTerminator();
  for (auto sharding_and_retval :
       llvm::zip(sharding_for_rets, terminator->getOpOperands())) {
    StringRef sharding = std::get<0>(sharding_and_retval);
    OpOperand& retval = std::get<1>(sharding_and_retval);
    func.setResultAttr(retval.getOperandNumber(), kShardingAttr,
                       builder->getStringAttr(sharding));
  }

  // Update input/output sharding attributes on tf_device.cluster_func op.
  cluster_func->setAttr(tensorflow::kInputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_args));
  cluster_func->setAttr(tensorflow::kOutputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_rets));
}
367
runOnOperation()368 void TPUShardingIdentificationPass::runOnOperation() {
369 Builder builder(getOperation().getContext());
370
371 getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
372 IdentifyXlaShardingForTPUComputation(&builder, cluster_func);
373 });
374 }
375
376 } // anonymous namespace
377
// Factory used by pass pipelines to instantiate the sharding-identification
// pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
  return std::make_unique<TPUShardingIdentificationPass>();
}
381
// Registers the pass with the global MLIR registry under the flag returned by
// getArgument() ("tf-tpu-sharding-identification").
static PassRegistration<TPUShardingIdentificationPass> pass;
383
384 } // namespace TFTPU
385 } // namespace mlir
386