/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kReplicateSharding[] = "";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning";

struct TPUShardingIdentificationPass
    : public PassWrapper<TPUShardingIdentificationPass,
                         OperationPass<ModuleOp>> {
  StringRef getArgument() const final {
    return "tf-tpu-sharding-identification";
  }

  StringRef getDescription() const final {
    return "Identifies and handles inputs/outputs of TPU computation that is "
           "sharded across logical cores.";
  }

  void runOnOperation() override;
};

// Returns the XLA sharding from a TPUPartitionedInput op connected to a
// `tf_device.cluster_func` operand value. If the value is a resource type,
// the TPUPartitionedInput op is connected through a ReadVariable op that
// feeds into the `tf_device.cluster_func`.
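//
// For example (illustrative IR; the sharding string and tensor types below
// are simplified placeholders):
//
//   %pi = "tf.TPUPartitionedInput"(%0, %1) {_XlaSharding = "sharding"}
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
//   "tf_device.cluster_func"(%pi) {func = @computation} ...
//
// Here the function returns "sharding" for the corresponding operand.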
llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
  Value value_to_visit = value;
  if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
          value_to_visit.getDefiningOp()))
    value_to_visit = read_var.resource();

  if (auto partitioned_input =
          llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
              value_to_visit.getDefiningOp()))
    return partitioned_input._XlaSharding();

  return llvm::None;
}

// Returns the XLA sharding from an XlaSharding op connected to an argument
// value. If the value is a resource type, the XlaSharding op is connected
// through a ReadVariable op. The XlaSharding op may be a direct user of the
// input, but it may also be separated from it by an Identity op and, when
// the bfloat16 type is used, by a Cast op added right after the input.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow
// (If, Case, While) ops and Caller return values.
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect
// sharded inputs.
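//
// For example, inside the function called by `tf_device.cluster_func`
// (illustrative, simplified):
//
//   func @computation(%arg0: tensor<8xf32>) {
//     %id = "tf.Identity"(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
//     %s = "tf.XlaSharding"(%id) {_XlaSharding = "sharding"} ...
//     ...
//   }
//
// The breadth-first walk below starts at %arg0, steps through the Identity
// op, and returns "sharding" from the XlaSharding user.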
llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
  llvm::SmallPtrSet<Value, 4> visited_values;
  llvm::SmallVector<Value, 4> values_to_visit{value};
  while (!values_to_visit.empty()) {
    llvm::SmallVector<Value, 4> next_values_to_visit;
    for (Value value_to_visit : values_to_visit) {
      if (!visited_values.insert(value_to_visit).second) continue;

      for (auto& use : value_to_visit.getUses()) {
        Operation* owner = use.getOwner();
        if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
          return sharding._XlaSharding();

        if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
          next_values_to_visit.push_back(use.getOwner()->getResult(0));
          continue;
        }

        if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) {
          FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
          if (!func) continue;
          next_values_to_visit.push_back(
              func.getArgument(use.getOperandNumber()));
        }
      }
    }

    values_to_visit.swap(next_values_to_visit);
  }

  return llvm::None;
}

// Extracts sharding configurations for all inputs by parsing the XlaSharding/
// TPUPartitionedInput ops connected to the operands/arguments. If an argument
// of the `cluster_func` directly feeds into another function call op, the
// function definition is walked recursively to find the connected XlaSharding
// op.
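//
// The sharding chosen for each argument is, in priority order: the sharding
// of a connected TPUPartitionedInput op; replicate sharding when `use_spmd`
// is set; the sharding of a connected XlaSharding op; and finally maximal
// sharding on logical core 0 as the default.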
void IdentifyXlaShardingForComputationInputs(
    StringRef logical_core_0_sharding, bool use_spmd,
    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder,
    llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) {
  // Look up function definition from module.
  Block& function_block = func.front();

  sharding_for_args.reserve(function_block.getNumArguments());

  // Iterate through operands of `cluster_func`.
  // The computation operand can either be:
  //   1) a TPUPartitionedInput op if the input has a non-resource type;
  //   2) a ReadVariableOp otherwise.
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through input arguments to the entry block of
  // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
  // XlaSharding ops can:
  //   1) Directly follow the input argument if the input argument has a
  //      non-resource type.
  //   2) Follow a ReadVariableOp if the input is of resource type.
  //   3) Follow an IdentityOp or CastOp after the above cases (1), (2).
  //
  // Sharding configurations are added to the tf_device.ClusterFunc as an
  // attribute and to the function as an argument attribute.
  for (auto operand_and_arg :
       llvm::zip(cluster_func.operands(), function_block.getArguments())) {
    Value operand = std::get<0>(operand_and_arg);
    BlockArgument arg = std::get<1>(operand_and_arg);

    if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
      sharding_for_args.push_back(operand_sharding.getValue());
      continue;
    }

    if (use_spmd) {
      // If XLA SPMD is enabled, host variables or non-variable per-replica
      // inputs should take on replicate sharding, unless another sharding is
      // set via a TPUPartitionedInput op.
      sharding_for_args.push_back(kReplicateSharding);
      continue;
    }

    auto arg_sharding = GetXlaShardingFromArg(arg);
    if (arg_sharding) {
      sharding_for_args.push_back(arg_sharding.getValue());
      continue;
    }

    // Default to maximal sharding on logical core 0 if no sharding is present.
    sharding_for_args.push_back(logical_core_0_sharding);
  }
}

// Returns the XLA sharding from a TPUPartitionedOutput op or a
// TPUPartitionedInput op (via an AssignVariableOp/resource write) connected
// to a `tf_device.cluster_func` result value.
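//
// For example (illustrative, simplified):
//
//   %result = "tf_device.cluster_func"(...) {func = @computation} ...
//   "tf.TPUPartitionedOutput"(%result) {_XlaSharding = "sharding"} ...
//
// or, for a resource write:
//
//   %pi = "tf.TPUPartitionedInput"(...) {_XlaSharding = "sharding"} ...
//   "tf.AssignVariableOp"(%pi, %result) ...
//
// Either form yields "sharding" for `%result`.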
llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
  if (!value.hasOneUse()) return llvm::None;

  Operation* user = *value.getUsers().begin();
  if (auto partitioned_output =
          llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
    return partitioned_output._XlaSharding();

  if (auto assign_var = llvm::dyn_cast<TF::AssignVariableOp>(user))
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
                assign_var.resource().getDefiningOp()))
      return partitioned_input._XlaSharding();

  return llvm::None;
}

// Returns the XLA sharding from an XlaSharding op connected to a result
// value. The XlaSharding op may be the direct producer of the returned
// value, but it may also be separated from it by an Identity op and, when
// the bfloat16 type is used, by a Cast op.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow
// (If, Case, While) ops and Caller argument values.
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect
// sharded inputs.
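//
// For example, inside the callee (illustrative, simplified):
//
//   %s = "tf.XlaSharding"(%0) {_XlaSharding = "sharding"} ...
//   %id = "tf.Identity"(%s) : (tensor<8xf32>) -> tensor<8xf32>
//   return %id : tensor<8xf32>
//
// Walking backwards from the returned value through the Identity op reaches
// the XlaSharding op and returns "sharding".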
llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) {
  llvm::SmallPtrSet<Value, 4> visited_values;
  Value value_to_visit = value;
  while (value_to_visit) {
    if (!visited_values.insert(value_to_visit).second) return llvm::None;

    Operation* def = value_to_visit.getDefiningOp();
    if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
      return sharding._XlaSharding();

    if (llvm::isa_and_nonnull<TF::IdentityOp, TF::CastOp>(def)) {
      value_to_visit = def->getOperand(0);
      continue;
    }

    if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
      FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
      if (!func) continue;
      value_to_visit = func.front().getTerminator()->getOperand(
          value_to_visit.cast<OpResult>().getResultNumber());
      continue;
    }

    break;
  }

  return llvm::None;
}

// Extracts sharding configurations for all outputs by parsing the XlaSharding/
// TPUPartitionedOutput ops connected to the retvals/results.
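//
// As with the inputs, the sharding chosen for each result is, in priority
// order: the sharding of a connected TPUPartitionedOutput op (or resource
// write); replicate sharding when `use_spmd` is set; the sharding of a
// connected XlaSharding op; and finally maximal sharding on logical core 0
// as the default.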
void IdentifyXlaShardingForComputationOutputs(
    StringRef logical_core_0_sharding, bool use_spmd,
    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder,
    llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) {
  Block& function_block = func.front();
  Operation* terminator = function_block.getTerminator();
  sharding_for_rets.reserve(terminator->getNumOperands());

  // Iterate through results of `cluster_func`. For output ops, look for
  // TPUPartitionedOutput ops.
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through operands of the terminator. If the preceding op is an
  // XlaShardingOp, then the provided sharding configuration is added to the
  // tf_device.ClusterFunc as an attribute and to the function as a result
  // attribute.
  for (auto result_and_retval :
       llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
    Value result = std::get<0>(result_and_retval);
    OpOperand& retval = std::get<1>(result_and_retval);

    if (auto result_sharding = GetXlaShardingFromResult(result)) {
      sharding_for_rets.push_back(result_sharding.getValue());
      continue;
    }

    if (use_spmd) {
      // If XLA SPMD is enabled, all outputs should have replicate sharding,
      // unless another sharding is set via a TPUPartitionedOutput op.
      sharding_for_rets.push_back(kReplicateSharding);
      continue;
    }

    if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
      sharding_for_rets.push_back(retval_sharding.getValue());
      continue;
    }

    // Default to maximal sharding on logical core 0 if no sharding is present.
    sharding_for_rets.push_back(logical_core_0_sharding);
  }
}

// Extracts the input/output sharding configuration of `cluster_func` by
// parsing XlaSharding ops inside the `cluster_func`.
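//
// After this runs, `cluster_func` carries the discovered shardings as string
// array attributes (named by tensorflow::kInputShardingAttr and
// tensorflow::kOutputShardingAttr), and the callee gains matching
// mhlo.sharding argument/result attributes. An illustrative, simplified
// result:
//
//   "tf_device.cluster_func"(%0, %1) {func = @computation,
//       input_sharding_configuration = ["sharding0", "sharding1"],
//       output_sharding_configuration = ["sharding2"]} ...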
void IdentifyXlaShardingForTPUComputation(
    Builder* builder, tf_device::ClusterFuncOp cluster_func) {
  // Look up function definition from module.
  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      cluster_func.func());

  // By default, inputs/outputs have maximal sharding and are assigned to
  // logical core 0 if no sharding is defined.
  const std::string logical_core_0_sharding =
      xla::sharding_builder::AssignDevice(0).SerializeAsString();

  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(kUseSpmdAttr))
    use_spmd = use_spmd_attr.getValue();

  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
                                          cluster_func, func, builder,
                                          sharding_for_args);

  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
  IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, use_spmd,
                                           cluster_func, func, builder,
                                           sharding_for_rets);

  auto has_maximal_sharding = [](llvm::StringRef sharding_string) -> bool {
    xla::OpSharding sharding;
    sharding.ParseFromString(sharding_string.str());
    return sharding.type() == xla::OpSharding::MAXIMAL;
  };

  // XLA SPMD only supports cases where all inputs/outputs exist on every
  // partition (sharded or replicated). If any of the inputs/outputs have
  // maximal sharding, fall back to MPMD.
  if (use_spmd && (absl::c_any_of(sharding_for_args, has_maximal_sharding) ||
                   absl::c_any_of(sharding_for_rets, has_maximal_sharding))) {
    LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
                    "exist on every partition (sharded or replicated). Some "
                    "of the inputs/outputs have maximal sharding, so falling "
                    "back to MPMD.";
    sharding_for_args.clear();
    sharding_for_rets.clear();
    cluster_func->setAttr(kUseSpmdAttr, builder->getBoolAttr(false));
    IdentifyXlaShardingForComputationInputs(logical_core_0_sharding,
                                            /*use_spmd=*/false, cluster_func,
                                            func, builder, sharding_for_args);
    IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding,
                                             /*use_spmd=*/false, cluster_func,
                                             func, builder, sharding_for_rets);
  }

  // Update sharding on function arguments and returns.
  Block& function_block = func.front();
  for (auto sharding_and_arg :
       llvm::zip(sharding_for_args, function_block.getArguments())) {
    StringRef sharding = std::get<0>(sharding_and_arg);
    BlockArgument arg = std::get<1>(sharding_and_arg);
    func.setArgAttr(arg.getArgNumber(), kShardingAttr,
                    builder->getStringAttr(sharding));
  }

  Operation* terminator = function_block.getTerminator();
  for (auto sharding_and_retval :
       llvm::zip(sharding_for_rets, terminator->getOpOperands())) {
    StringRef sharding = std::get<0>(sharding_and_retval);
    OpOperand& retval = std::get<1>(sharding_and_retval);
    func.setResultAttr(retval.getOperandNumber(), kShardingAttr,
                       builder->getStringAttr(sharding));
  }

  // Update input/output sharding attributes on the tf_device.cluster_func op.
  cluster_func->setAttr(tensorflow::kInputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_args));
  cluster_func->setAttr(tensorflow::kOutputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_rets));
}

void TPUShardingIdentificationPass::runOnOperation() {
  Builder builder(getOperation().getContext());

  getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
    IdentifyXlaShardingForTPUComputation(&builder, cluster_func);
  });
}

}  // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
  return std::make_unique<TPUShardingIdentificationPass>();
}

static PassRegistration<TPUShardingIdentificationPass> pass;

}  // namespace TFTPU
}  // namespace mlir