/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kShardingAttr[] = "mhlo.sharding";
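// Note: an empty string deserializes to a default xla::OpSharding proto, whose
// type field defaults to REPLICATED, so downstream consumers treat it as
// replicated sharding.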
constexpr char kReplicateSharding[] = "";

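// Module pass that identifies the XLA sharding configurations for the inputs
// and outputs of `tf_device.cluster_func` ops (TPU computations) and attaches
// them as attributes; see the pass registration at the bottom of this file.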
struct TPUShardingIdentificationPass
    : public PassWrapper<TPUShardingIdentificationPass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Returns the XLA sharding from a TPUPartitionedInput op connected to a
// `tf_device.cluster_func` operand value. If the value is a resource type,
// the TPUPartitionedInput op will instead be connected to a ReadVariableOp
// that feeds into the `tf_device.cluster_func`.
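//
// A simplified sketch of the two matched patterns (illustrative, not verbatim
// IR):
//
//   %in = "tf.TPUPartitionedInput"(...) {_XlaSharding = "..."}
//   "tf_device.cluster_func"(%in)
//
//   %in = "tf.TPUPartitionedInput"(%var0, %var1) {_XlaSharding = "..."}
//   %read = "tf.ReadVariableOp"(%in)
//   "tf_device.cluster_func"(%read)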
llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
  Value value_to_visit = value;
  if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
          value_to_visit.getDefiningOp()))
    value_to_visit = read_var.resource();

  if (auto partitioned_input =
          llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
              value_to_visit.getDefiningOp()))
    return partitioned_input._XlaSharding();

  return llvm::None;
}

// Returns the XLA sharding from an XlaSharding op connected to an argument
// value. If the value is a resource type, the XlaSharding op will be connected
// to a ReadVariableOp. The XlaSharding op may be a direct user of the input,
// but it may also be separated from it by an Identity op and, in the case
// where the bfloat16 type is used, by a Cast op added right after the input.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow
// (If, Case, While) ops and Caller return values.
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect
// sharded inputs.
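//
// For example, each of the following chains from a block argument to an
// XlaSharding op is recognized by the forward walk below (illustrative
// sketch):
//
//   %arg -> tf.XlaSharding
//   %arg -> tf.Identity -> tf.XlaSharding
//   %arg -> tf.Cast -> tf.XlaSharding            (bfloat16 inputs)
//   %arg -> tf.ReadVariableOp -> tf.XlaSharding  (resource inputs)
//   %arg -> call @fn, continuing from @fn's corresponding argument.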
llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
  llvm::SmallPtrSet<Value, 4> visited_values;
  llvm::SmallVector<Value, 4> values_to_visit{value};
  while (!values_to_visit.empty()) {
    llvm::SmallVector<Value, 4> next_values_to_visit;
    for (Value value_to_visit : values_to_visit) {
      if (!visited_values.insert(value_to_visit).second) continue;

      for (auto& use : value_to_visit.getUses()) {
        Operation* owner = use.getOwner();
        if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
          return sharding._XlaSharding();

        if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
          next_values_to_visit.push_back(use.getOwner()->getResult(0));
          continue;
        }

        if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) {
          FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
          if (!func) continue;
          next_values_to_visit.push_back(
              func.getArgument(use.getOperandNumber()));
        }
      }
    }

    values_to_visit.swap(next_values_to_visit);
  }

  return llvm::None;
}

// Extracts sharding configurations for all inputs by parsing the XlaSharding/
// TPUPartitionedInput ops connected to the operands/arguments. If an argument
// of the `cluster_func` directly feeds into another function call op, then
// recursively walk the function definition to find the connected XlaSharding
// op.
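//
// After this runs, the cluster op and its callee carry the computed shardings,
// roughly as follows (illustrative; attribute spellings taken from
// `tensorflow::kInputShardingAttr` and `kShardingAttr`, values elided):
//
//   "tf_device.cluster_func"(...) {input_sharding_configuration = [...]}
//   func @_func(%arg0: tensor<...> {mhlo.sharding = "..."})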
void IdentifyXlaShardingForComputationInputs(
    StringRef logical_core_0_sharding, bool use_spmd,
    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
  // Entry block of the computation function.
  Block& function_block = func.front();

  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
  sharding_for_args.reserve(function_block.getNumArguments());

  // Iterate through operands of `cluster_func`.
  // The computation operand can either be:
  //   1) a TPUPartitionedInput op if the input has a non-resource type;
  //   2) a ReadVariableOp if the input is a resource type.
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through input arguments to the entry block of
  // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
  // XlaSharding ops can:
  //   1) Directly follow the input argument if the input argument has a
  //      non-resource type.
  //   2) Follow a ReadVariableOp if the input is of resource type.
  //   3) Follow an IdentityOp or CastOp after the above cases (1), (2).
  //
  // Sharding configurations are added to the tf_device.ClusterFunc as an
  // attribute and to the function as an argument attribute.
  for (auto operand_and_arg :
       llvm::zip(cluster_func.operands(), function_block.getArguments())) {
    Value operand = std::get<0>(operand_and_arg);
    BlockArgument arg = std::get<1>(operand_and_arg);
    const int index = arg.getArgNumber();

    if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
      sharding_for_args.push_back(operand_sharding.getValue());
      func.setArgAttr(index, kShardingAttr,
                      builder->getStringAttr(operand_sharding.getValue()));
      continue;
    }

    if (use_spmd) {
      // If XLA SPMD is enabled, host variables or non-variable per-replica
      // inputs should take on replicate sharding, unless another sharding is
      // set via a TPUPartitionedInput op.
      sharding_for_args.push_back(kReplicateSharding);
      func.setArgAttr(index, kShardingAttr,
                      builder->getStringAttr(kReplicateSharding));
      continue;
    }

    auto arg_sharding = GetXlaShardingFromArg(arg);
    if (arg_sharding) {
      sharding_for_args.push_back(arg_sharding.getValue());
      func.setArgAttr(index, kShardingAttr,
                      builder->getStringAttr(arg_sharding.getValue()));
      continue;
    }

    // Default to maximal sharding on logical core 0 if no sharding is present.
    sharding_for_args.push_back(logical_core_0_sharding);
    func.setArgAttr(index, kShardingAttr,
                    builder->getStringAttr(logical_core_0_sharding));
  }

  cluster_func->setAttr(tensorflow::kInputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_args));
}

// Returns the XLA sharding from a TPUPartitionedOutput op, or from a
// TPUPartitionedInput op (via an AssignVariableOp/resource write), connected
// to a `tf_device.cluster_func` result value.
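//
// A simplified sketch of the two matched patterns (illustrative, not verbatim
// IR):
//
//   %result = "tf_device.cluster_func"(...)
//   "tf.TPUPartitionedOutput"(%result) {_XlaSharding = "..."}
//
//   %var = "tf.TPUPartitionedInput"(...) {_XlaSharding = "..."}
//   "tf.AssignVariableOp"(%var, %result)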
llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
  if (!value.hasOneUse()) return llvm::None;

  Operation* user = *value.getUsers().begin();
  if (auto partitioned_output =
          llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
    return partitioned_output._XlaSharding();

  if (auto assign_var = llvm::dyn_cast<TF::AssignVariableOp>(user))
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
                assign_var.resource().getDefiningOp()))
      return partitioned_input._XlaSharding();

  return llvm::None;
}

// Returns the XLA sharding from an XlaSharding op connected to a result value.
// The XlaSharding op may directly define the returned value, but it may also
// be separated from it by an Identity op and, in the case where the bfloat16
// type is used, by a Cast op.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow
// (If, Case, While) ops and Caller argument values.
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect
// sharded inputs.
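//
// For example, each of the following chains ending at a returned value is
// recognized by walking defining ops backwards (illustrative sketch):
//
//   tf.XlaSharding -> return
//   tf.XlaSharding -> tf.Identity -> return
//   tf.XlaSharding -> tf.Cast -> return   (bfloat16 outputs)
//
// Call ops are stepped through into the callee terminator's corresponding
// operand.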
llvm::Optional<llvm::StringRef> GetXlaShardingFromRetval(Value value) {
  llvm::SmallPtrSet<Value, 4> visited_values;
  Value value_to_visit = value;
  while (value_to_visit) {
    if (!visited_values.insert(value_to_visit).second) return llvm::None;

    Operation* def = value_to_visit.getDefiningOp();
    if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
      return sharding._XlaSharding();

    if (llvm::isa_and_nonnull<TF::IdentityOp, TF::CastOp>(def)) {
      value_to_visit = def->getOperand(0);
      continue;
    }

    if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
      FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
      if (!func) continue;
      value_to_visit = func.front().getTerminator()->getOperand(
          value_to_visit.cast<OpResult>().getResultNumber());
      continue;
    }

    break;
  }

  return llvm::None;
}

// Extracts sharding configurations for all outputs by parsing the XlaSharding/
// TPUPartitionedOutput ops connected to the retvals/results.
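//
// As with the inputs, the computed shardings land on the cluster op and its
// callee, roughly as follows (illustrative; attribute spellings taken from
// `tensorflow::kOutputShardingAttr` and `kShardingAttr`):
//
//   "tf_device.cluster_func"(...) {output_sharding_configuration = [...]}
//   func @_func(...) -> (tensor<...> {mhlo.sharding = "..."})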
void IdentifyXlaShardingForComputationOutputs(
    StringRef logical_core_0_sharding, bool use_spmd,
    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
  Block& function_block = func.front();
  Operation* terminator = function_block.getTerminator();
  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
  sharding_for_rets.reserve(terminator->getNumOperands());

  // Iterate through results of `cluster_func`. For output ops, look for
  // TPUPartitionedOutput ops.
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through operands of the terminator. If the preceding op is an
  // XlaShardingOp, then the provided sharding configuration is added to the
  // tf_device.ClusterFunc as an attribute and to the function as a result
  // attribute.
  for (auto result_and_retval :
       llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
    Value result = std::get<0>(result_and_retval);
    OpOperand& retval = std::get<1>(result_and_retval);
    const int index = retval.getOperandNumber();

    if (auto result_sharding = GetXlaShardingFromResult(result)) {
      sharding_for_rets.push_back(result_sharding.getValue());
      func.setResultAttr(index, kShardingAttr,
                         builder->getStringAttr(result_sharding.getValue()));
      continue;
    }

    if (use_spmd) {
      // If XLA SPMD is enabled, all outputs should have replicate sharding,
      // unless another sharding is set via a TPUPartitionedOutput op.
      sharding_for_rets.push_back(kReplicateSharding);
      func.setResultAttr(index, kShardingAttr,
                         builder->getStringAttr(kReplicateSharding));
      continue;
    }

    if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
      sharding_for_rets.push_back(retval_sharding.getValue());
      func.setResultAttr(index, kShardingAttr,
                         builder->getStringAttr(retval_sharding.getValue()));
      continue;
    }

    // Default to maximal sharding on logical core 0 if no sharding is present.
    sharding_for_rets.push_back(logical_core_0_sharding);
    func.setResultAttr(index, kShardingAttr,
                       builder->getStringAttr(logical_core_0_sharding));
  }

  cluster_func->setAttr(tensorflow::kOutputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_rets));
}

// Extracts input/output sharding configuration of `cluster_func` by parsing
// XlaSharding ops inside the `cluster_func`.
void IdentifyXlaShardingForTPUComputation(
    Builder* builder, tf_device::ClusterFuncOp cluster_func) {
  // Look up function definition from module.
  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      cluster_func.func());

  // By default inputs/outputs have maximal sharding and are assigned to
  // logical core 0 if no sharding is defined.
  const std::string logical_core_0_sharding =
      xla::sharding_builder::AssignDevice(0).SerializeAsString();

  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
          "use_spmd_for_xla_partitioning"))
    use_spmd = use_spmd_attr.getValue();

  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
                                          cluster_func, func, builder);

  IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, use_spmd,
                                           cluster_func, func, builder);
}

void TPUShardingIdentificationPass::runOnOperation() {
  Builder builder(getOperation().getContext());

  getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
    IdentifyXlaShardingForTPUComputation(&builder, cluster_func);
  });
}

}  // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
  return std::make_unique<TPUShardingIdentificationPass>();
}

static PassRegistration<TPUShardingIdentificationPass> pass(
    "tf-tpu-sharding-identification",
    "Identifies and handles inputs/outputs of a TPU computation that is "
    "sharded across logical cores.");

}  // namespace TFTPU
}  // namespace mlir