• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Block.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/Operation.h"  // from @llvm-project
32 #include "mlir/IR/Value.h"  // from @llvm-project
33 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
36 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
41 #include "tensorflow/compiler/xla/client/sharding_builder.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 
44 namespace mlir {
45 namespace TFTPU {
46 namespace {
47 
// Attribute names this pass reads from or writes to.
constexpr char kAliasingAttr[] = "tf.aliasing_output";
constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning";
constexpr char kShardingAttr[] = "mhlo.sharding";

// Sharding string pushed for inputs/outputs that should take replicate
// sharding when XLA SPMD is enabled; it is the empty string.
constexpr char kReplicateSharding[] = "";
52 
53 struct TPUShardingIdentificationPass
54     : public TF::TPUShardingIdentificationPassBase<
55           TPUShardingIdentificationPass> {
56   void runOnOperation() final;
57 };
58 
59 // Returns XLA sharding from TPUPartitionedInput op connected to a
60 // `tf_device.cluster_func` operand value. If value is a resource type then
61 // TPUPartitionedInput op will be connected to a ReadVariable op that feeds into
62 // a `tf_device.cluster_func`.
GetXlaShardingFromOperand(Value value)63 llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
64   Value value_to_visit = value;
65   if (auto read_var = value_to_visit.getDefiningOp<TF::ReadVariableOp>())
66     value_to_visit = read_var.resource();
67 
68   if (auto partitioned_input =
69           value_to_visit.getDefiningOp<TF::TPUPartitionedInputOp>())
70     return partitioned_input._XlaSharding();
71 
72   return llvm::None;
73 }
74 
75 // Given a `tf_device.cluster_func` operand value return true iff it a device
76 // variable that should default to MAXIMAL sharding. Device variables that are
77 // per-replica or distributed default to MAXIMAL sharding, which corresponds to
78 // arguments of the `tf_device.replicate`. Otherwise the variable is broadcast,
79 // which corresponds to edges that are implicitly captured by the `replicate`.
IsMaximalVariable(Value value)80 bool IsMaximalVariable(Value value) {
81   auto read_var = value.getDefiningOp<TF::ReadVariableOp>();
82   return read_var && read_var->getParentOfType<tf_device::ReplicateOp>();
83 }
84 
85 // Verify whether the given sharding can be applied to the given (tensor) type.
86 // (A bad sharding might mean failing tf.Split ops if the graph later executes
87 //  on CPU)
88 // If the sharding is incorrect, return failure. If it's good, or if we can't
89 // verify it, return success.
VerifySharding(Type type,StringRef sharding_string)90 LogicalResult VerifySharding(Type type, StringRef sharding_string) {
91   xla::OpSharding sharding;
92   if (!sharding.ParseFromString(sharding_string.str())) {
93     // Some test cases use \01\02\03 as sharding, to test propagation. Treat
94     // a non-proto sharding as valid, and don't verify further.
95     return success();
96   }
97   if (sharding.type() != xla::OpSharding::OTHER) {
98     // We currently only verify shardings that actually break a tensor apart.
99     return success();
100   }
101   if (RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>()) {
102     if (ranked_type.getRank() < sharding.tile_assignment_dimensions_size()) {
103       return failure();
104     }
105   }
106   return success();
107 }
108 
109 // Verify sharding for all arguments and return values.
VerifyShardings(mlir::func::FuncOp func,const llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_args,const llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_rets)110 LogicalResult VerifyShardings(
111     mlir::func::FuncOp func,
112     const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args,
113     const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) {
114   Block& function_block = func.front();
115   for (auto sharding_and_arg :
116        llvm::zip(sharding_for_args, function_block.getArguments())) {
117     StringRef sharding = std::get<0>(sharding_and_arg);
118     BlockArgument arg = std::get<1>(sharding_and_arg);
119     if (failed(VerifySharding(arg.getType(), sharding))) return failure();
120   }
121   Operation* terminator = function_block.getTerminator();
122   for (auto sharding_and_retval :
123        llvm::zip(sharding_for_rets, terminator->getOpOperands())) {
124     StringRef sharding = std::get<0>(sharding_and_retval);
125     OpOperand& retval = std::get<1>(sharding_and_retval);
126     if (failed(VerifySharding(retval.get().getType(), sharding)))
127       return failure();
128   }
129   return success();
130 }
131 
132 // Returns XLA sharding from a XlaSharding op connected to an argument value. If
133 // value is a resource type then XlaSharding op will be connected to a
134 // ReadVariable op. XlaSharding op may be direct user of inputs but it may also
135 // be followed by an Identity op and, in the case where bfloat16 type is used,
136 // Cast op may be added right after the input.
137 //
138 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
139 // Case, While) ops and Caller return values.
140 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
141 // inputs.
GetXlaShardingFromArg(Value value)142 llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
143   llvm::SmallPtrSet<Value, 4> visited_values;
144   llvm::SmallVector<Value, 4> values_to_visit{value};
145   while (!values_to_visit.empty()) {
146     llvm::SmallVector<Value, 4> next_values_to_visit;
147     for (Value value_to_visit : values_to_visit) {
148       if (!visited_values.insert(value_to_visit).second) continue;
149 
150       for (auto& use : value_to_visit.getUses()) {
151         Operation* owner = use.getOwner();
152         if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
153           return sharding._XlaSharding();
154 
155         if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
156           next_values_to_visit.push_back(use.getOwner()->getResult(0));
157           continue;
158         }
159 
160         if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) {
161           func::FuncOp func =
162               llvm::dyn_cast<func::FuncOp>(call_op.resolveCallable());
163           if (!func) continue;
164           next_values_to_visit.push_back(
165               func.getArgument(use.getOperandNumber()));
166         }
167       }
168     }
169 
170     values_to_visit.swap(next_values_to_visit);
171   }
172 
173   return llvm::None;
174 }
175 
176 // Extracts sharding configurations for all inputs by parsing XlaSharding/
177 // TPUPartitionedInput op connected to the operands/arguments. If argument to
178 // the `cluster_func` directly feeds into another function call op, then
179 // recursively walk the function definition to find the connected XlaSharding
180 // op.
IdentifyXlaShardingForComputationInputs(StringRef logical_core_0_sharding,bool use_spmd,bool infer_from_computation,tf_device::ClusterFuncOp cluster_func,func::FuncOp func,Builder * builder,llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_args)181 void IdentifyXlaShardingForComputationInputs(
182     StringRef logical_core_0_sharding, bool use_spmd,
183     bool infer_from_computation, tf_device::ClusterFuncOp cluster_func,
184     func::FuncOp func, Builder* builder,
185     llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) {
186   // Look up function definition from module.
187   Block& function_block = func.front();
188 
189   sharding_for_args.reserve(function_block.getNumArguments());
190 
191   // Iterate through operands of `cluster_func`.
192   // The computation operand can either be:
193   //   1) a TPUPartitionedInput Op if the input has a non-resource type;
194   //   2) a ReadVariableOp else.
195   //
196   // Replicate sharding is used if `use_spmd` is set.
197   //
198   // Iterate through input arguments to the entry block of
199   // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
200   // XlaSharding ops can:
201   //   1) Directly follow the input argument if input argument has non-resource
202   //      types.
203   //   2) Follow ReadVariableOp if the input type is of resource type.
204   //   3) Follow IdentityOp or CastOp after above cases (1), (2).
205   //
206   // Sharding configurations are added to the tf_device.ClusterFunc as an
207   // attribute and the function as an argument attribute.
208   for (auto operand_and_arg :
209        llvm::zip(cluster_func.operands(), function_block.getArguments())) {
210     Value operand = std::get<0>(operand_and_arg);
211     BlockArgument arg = std::get<1>(operand_and_arg);
212 
213     if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
214       sharding_for_args.push_back(operand_sharding.getValue());
215       continue;
216     }
217 
218     if (infer_from_computation) {
219       auto arg_sharding = GetXlaShardingFromArg(arg);
220       if (arg_sharding) {
221         sharding_for_args.push_back(arg_sharding.getValue());
222         continue;
223       }
224     }
225 
226     if (use_spmd && !IsMaximalVariable(operand)) {
227       // If XLA SPMD is enabled, host variables or non-variable per-replica
228       // inputs should take on replicate sharding, so that every device gets the
229       // whole tensor(s) (and can slice them up later). Exclude device
230       // variables, which always should take maximal sharding.
231       sharding_for_args.push_back(kReplicateSharding);
232       continue;
233     }
234 
235     // Otherwise, default to maximal sharding core 0.
236     sharding_for_args.push_back(logical_core_0_sharding);
237   }
238 }
239 
240 // Returns XLA sharding from TPUPartitionedOutput or TPUPartitionedInput (via
241 // AssignVariableOp/resource write) op connected to a `tf_device.cluster_func`
242 // result value.
GetXlaShardingFromResult(Value value)243 llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
244   if (!value.hasOneUse()) return llvm::None;
245 
246   Operation* user = *value.getUsers().begin();
247   if (auto partitioned_output =
248           llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
249     return partitioned_output._XlaSharding();
250 
251   if (auto assign_var = llvm::dyn_cast<TF::AssignVariableOp>(user))
252     if (auto partitioned_input =
253             assign_var.resource().getDefiningOp<TF::TPUPartitionedInputOp>())
254       return partitioned_input._XlaSharding();
255 
256   return llvm::None;
257 }
258 
259 // Looks up arg->retval aliases for every argument, and builds a reverse map.
ExtractAliases(func::FuncOp func,llvm::SmallVectorImpl<int> & aliases)260 void ExtractAliases(func::FuncOp func, llvm::SmallVectorImpl<int>& aliases) {
261   aliases.resize(func.getNumResults(), -1);
262   for (int i = 0; i < func.getNumArguments(); i++) {
263     if (auto v = func.getArgAttrOfType<mlir::IntegerAttr>(i, kAliasingAttr)) {
264       int retval_index = v.getInt();
265       if (retval_index >= 0 && retval_index < aliases.size()) {
266         aliases[retval_index] = i;
267       }
268     }
269   }
270 }
271 
272 // Returns XLA sharding from argument connected via tf.aliasing_output.
GetXlaShardingFromAlias(Value value,llvm::SmallVectorImpl<int> & aliases,const llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_args)273 llvm::Optional<StringRef> GetXlaShardingFromAlias(
274     Value value, llvm::SmallVectorImpl<int>& aliases,
275     const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) {
276   int retval_index = value.cast<OpResult>().getResultNumber();
277   if (retval_index >= 0 && retval_index < aliases.size()) {
278     int arg_index = aliases[retval_index];
279     if (arg_index >= 0 && arg_index < sharding_for_args.size()) {
280       return sharding_for_args[arg_index];
281     }
282   }
283   return llvm::None;
284 }
285 
286 // Returns XLA sharding from XlaSharding op connected to a result value.
287 // XlaSharding op may be directly connected to output but it may also be
288 // followed by Identity or simple arithmetic ops. In case where bfloat16 type is
289 // used, we might see a Cast op.
290 //
291 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
292 // Case, While) ops and Caller argument values.
293 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
294 // inputs.
GetXlaShardingFromRetval(Value value)295 llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) {
296   llvm::SmallPtrSet<Value, 4> visited_values;
297   llvm::SmallVector<Value, 4> values_to_visit;
298   values_to_visit.push_back(value);
299 
300   while (!values_to_visit.empty()) {
301     Value value_to_visit = values_to_visit.pop_back_val();
302 
303     if (!visited_values.insert(value_to_visit).second) {
304       continue;
305     }
306 
307     Operation* def = value_to_visit.getDefiningOp();
308     if (!def) {
309       continue;
310     }
311 
312     if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
313       return sharding._XlaSharding();
314 
315     if (auto sharding = def->getAttrOfType<StringAttr>("_XlaSharding")) {
316       return sharding.strref();
317     }
318 
319     if (  // Cast, real/imag, etc.
320         def->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
321         // Exp, ceil, etc.
322         def->hasTrait<mlir::OpTrait::SameOperandsAndResultType>() ||
323         // Identity
324         def->hasTrait<mlir::OpTrait::TF::OperandsSameAsResultsTypeOrRef>() ||
325         // AddV2, Sub, etc.
326         (def->hasTrait<
327              mlir::OpTrait::TF::SameOperandsAndResultElementTypeResolveRef>() &&
328          def->hasTrait<mlir::OpTrait::TF::CwiseBinary>())) {
329       for (auto operand : def->getOperands()) {
330         values_to_visit.push_back(operand);
331       }
332       continue;
333     }
334 
335     if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
336       func::FuncOp func =
337           llvm::dyn_cast<func::FuncOp>(call_op.resolveCallable());
338       if (!func) continue;
339       value_to_visit = func.front().getTerminator()->getOperand(
340           value_to_visit.cast<OpResult>().getResultNumber());
341       values_to_visit.push_back(value_to_visit);
342       continue;
343     }
344   }
345 
346   return llvm::None;
347 }
348 
349 // Extracts sharding configurations for all outputs by parsing XlaSharding/
350 // TPUPartitionedOutput op connected to the retvals/results.
IdentifyXlaShardingForComputationOutputs(StringRef logical_core_0_sharding,bool use_spmd,bool infer_from_computation,tf_device::ClusterFuncOp cluster_func,func::FuncOp func,Builder * builder,const llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_args,llvm::SmallVectorImpl<llvm::StringRef> & sharding_for_rets)351 void IdentifyXlaShardingForComputationOutputs(
352     StringRef logical_core_0_sharding, bool use_spmd,
353     bool infer_from_computation, tf_device::ClusterFuncOp cluster_func,
354     func::FuncOp func, Builder* builder,
355     const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args,
356     llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) {
357   Block& function_block = func.front();
358   Operation* terminator = function_block.getTerminator();
359   sharding_for_rets.reserve(terminator->getNumOperands());
360 
361   llvm::SmallVector<int, 8> aliases;  // maps return value index to arg index
362   ExtractAliases(func, aliases);
363 
364   // Iterate through results of `cluster_func`. For output ops, look for
365   // TPUPartitionedOutput ops.
366   //
367   // Replicate sharding is used if `use_spmd` is set.
368   //
369   // Iterate through operands of the terminator. If the preceding op is
370   // XlaShardingOp, then the provided sharding configuration is added to the
371   // tf_device.ClusterFunc as an attribute and the function as a result
372   // attribute.
373   for (auto result_and_retval :
374        llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
375     Value result = std::get<0>(result_and_retval);
376     OpOperand& retval = std::get<1>(result_and_retval);
377 
378     if (auto result_sharding = GetXlaShardingFromResult(result)) {
379       sharding_for_rets.push_back(result_sharding.getValue());
380       continue;
381     }
382 
383     if (auto from_alias =
384             GetXlaShardingFromAlias(result, aliases, sharding_for_args)) {
385       sharding_for_rets.push_back(from_alias.getValue());
386       continue;
387     }
388 
389     if (infer_from_computation) {
390       if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
391         sharding_for_rets.push_back(retval_sharding.getValue());
392         continue;
393       }
394     }
395 
396     if (use_spmd) {
397       // If XLA SPMD is enabled, we default to replicate sharding. This way,
398       // all devices get the whole tensor(s), but if there's an XlaSharding op
399       // deeper in the function, they can use dynamic-slice to slice off their
400       // part of the computation.
401       sharding_for_rets.push_back(kReplicateSharding);
402       continue;
403     }
404 
405     // Otherwise, default to maximal sharding core 0.
406     sharding_for_rets.push_back(logical_core_0_sharding);
407   }
408 }
409 
410 // Extracts input/output sharding configuration of `cluster_func` by parsing
411 // XlaSharding ops inside the `cluster_func`.
IdentifyXlaShardingForTPUComputation(Builder * builder,tf_device::ClusterFuncOp cluster_func)412 void IdentifyXlaShardingForTPUComputation(
413     Builder* builder, tf_device::ClusterFuncOp cluster_func) {
414   // Look up function definition from module.
415   func::FuncOp func =
416       cluster_func->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(
417           cluster_func.func());
418 
419   // By default inputs/outputs have maximal sharding and are assigned to logical
420   // core 0 if no sharding is defined.
421   const std::string logical_core_0_sharding =
422       xla::sharding_builder::AssignDevice(0).SerializeAsString();
423 
424   bool use_spmd = false;
425   if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(kUseSpmdAttr))
426     use_spmd = use_spmd_attr.getValue();
427 
428   llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
429   IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
430                                           /*infer_from_computation=*/true,
431                                           cluster_func, func, builder,
432                                           sharding_for_args);
433 
434   llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
435   IdentifyXlaShardingForComputationOutputs(
436       logical_core_0_sharding, use_spmd, /*infer_from_computation=*/true,
437       cluster_func, func, builder, sharding_for_args, sharding_for_rets);
438 
439   auto has_maximal_sharding = [](llvm::StringRef sharding_string) -> bool {
440     xla::OpSharding sharding;
441     sharding.ParseFromString(sharding_string.str());
442     return sharding.type() == xla::OpSharding::MAXIMAL;
443   };
444 
445   // XLA SPMD only supports cases where all inputs/outputs exist on every
446   // partition (sharded or replicated). If any of the inputs/outputs have
447   // maximal sharding, then fallback to MPMD. Also fall back if any of the
448   // shardings aren't compatible with the rank of their tensor.
449   if ((use_spmd && (absl::c_any_of(sharding_for_args, has_maximal_sharding) ||
450                     absl::c_any_of(sharding_for_rets, has_maximal_sharding))) ||
451       failed(VerifyShardings(func, sharding_for_args, sharding_for_rets))) {
452     LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
453                     "exist on every partition (sharded or replicated). If any "
454                     "of the inputs/outputs have maximal sharding, then "
455                     "fallback to MPMD.";
456     sharding_for_args.clear();
457     sharding_for_rets.clear();
458     cluster_func->setAttr(kUseSpmdAttr, builder->getBoolAttr(false));
459 
460     IdentifyXlaShardingForComputationInputs(
461         logical_core_0_sharding,
462         /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func,
463         func, builder, sharding_for_args);
464     IdentifyXlaShardingForComputationOutputs(
465         logical_core_0_sharding,
466         /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func,
467         func, builder, sharding_for_args, sharding_for_rets);
468   }
469 
470   // Update sharding on function arguments and returns.
471   Block& function_block = func.front();
472   for (auto sharding_and_arg :
473        llvm::zip(sharding_for_args, function_block.getArguments())) {
474     StringRef sharding = std::get<0>(sharding_and_arg);
475     BlockArgument arg = std::get<1>(sharding_and_arg);
476     func.setArgAttr(arg.getArgNumber(), kShardingAttr,
477                     builder->getStringAttr(sharding));
478   }
479 
480   Operation* terminator = function_block.getTerminator();
481   for (auto sharding_and_retval :
482        llvm::zip(sharding_for_rets, terminator->getOpOperands())) {
483     StringRef sharding = std::get<0>(sharding_and_retval);
484     OpOperand& retval = std::get<1>(sharding_and_retval);
485     func.setResultAttr(retval.getOperandNumber(), kShardingAttr,
486                        builder->getStringAttr(sharding));
487   }
488 
489   // Update input/output sharding attributes on tf_device.cluster_func op.
490   cluster_func->setAttr(tensorflow::kInputShardingAttr,
491                         builder->getStrArrayAttr(sharding_for_args));
492   cluster_func->setAttr(tensorflow::kOutputShardingAttr,
493                         builder->getStrArrayAttr(sharding_for_rets));
494 }
495 
runOnOperation()496 void TPUShardingIdentificationPass::runOnOperation() {
497   Builder builder(getOperation().getContext());
498 
499   getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
500     IdentifyXlaShardingForTPUComputation(&builder, cluster_func);
501   });
502 }
503 
504 }  // anonymous namespace
505 
CreateTPUShardingIdentificationPass()506 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
507   return std::make_unique<TPUShardingIdentificationPass>();
508 }
509 
510 }  // namespace TFTPU
511 }  // namespace mlir
512