/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"

#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace tensorflow {

const char* const kInputShardingAttr = "input_sharding_configuration";
const char* const kOutputShardingAttr = "output_sharding_configuration";

namespace {

constexpr char kNumSplitAttr[] = "num_split";

// Creates a tf::SplitOp that splits 'src_input' into 'num_split' ways
// along the 'split_dimension' dimension and returns the split values.
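// For example, splitting a [8, 16] input 4 ways along dimension 0 yields
// four [2, 16] results.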
mlir::LogicalResult CreateSplitOp(const int num_split,
                                  const int split_dimension,
                                  const mlir::Location& location,
                                  mlir::Value src_input,
                                  mlir::OpBuilder* builder,
                                  mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold the split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly set the output shape of the split op if the input shape is
  // statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return mlir::emitError(
            location,
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split));
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  (*split_op)->setAttr(
      kNumSplitAttr,
      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
  return mlir::success();
}

// Creates a tf::ConcatOp that merges `inputs` values along `concat_dimension`.
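// For example, concatenating four [2, 4] inputs along dimension 1 yields a
// [2, 16] result.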
mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension,
                                  const mlir::Location& location,
                                  mlir::ArrayRef<mlir::Value> inputs,
                                  mlir::OpBuilder* builder) {
  // Creates a const op to hold the concat dimension value.
  auto concat_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto concat_dimension_attr =
      mlir::DenseElementsAttr::get(concat_dim_type, concat_dimension);
  auto concat_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, concat_dim_type, concat_dimension_attr);

  // Correctly set the output shape of the concat op if the input shape is
  // statically known. Since the output shape of the TPUExecute op must be the
  // same across logical devices, we refer to the output shape of the 0th
  // logical device computation.
  mlir::Type output_type;
  auto input_type = inputs[0].getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[concat_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      shape[concat_dimension] = shape[concat_dimension] * inputs.size();
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  return builder->create<mlir::TF::ConcatOp>(
      location, output_type, concat_dimension_op.output(), inputs);
}

// For tile-sharded inputs to a TPU computation, injects split ops between the
// input values and the TPU computation so that tiled input values are passed
// in as inputs to the TPU computations. If more than one dimension is sharded,
// a tree of connected split ops is added before the
// tf_device.parallel_execute op.
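// For example, for an input sharded 2 ways along dimension 0 and 2 ways along
// dimension 1, a root split op splits the input along dimension 0 and each of
// its two results is split again along dimension 1, producing four tiles in
// row-major device order.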
mlir::LogicalResult HandleTileShardedInputs(
    const mlir::Location& location, const xla::OpSharding& input_sharding,
    const mlir::Value& original_source, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
  split_ops_for_tiled_input.reserve(
      input_sharding.tile_assignment_devices_size());

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that input data is sharded in row-major order.
  // Split nodes at the i-th depth from the original input node split the
  // input data along the i-th dimension.
  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
    const int num_splits = num_splits_and_index.value();
    const int dimension_index = num_splits_and_index.index();
    if (num_splits == 1) continue;

    // Creates the root split op.
    if (split_ops_for_tiled_input.empty()) {
      mlir::TF::SplitOp root_split_op;
      auto result = CreateSplitOp(num_splits, dimension_index, location,
                                  original_source, builder, &root_split_op);
      if (mlir::failed(result)) return mlir::failure();

      split_ops_for_tiled_input.emplace_back(root_split_op);
      continue;
    }

    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);

    for (auto split_op : split_ops_for_tiled_input) {
      for (auto parent_split_output_value : split_op.getResults()) {
        mlir::TF::SplitOp child_split_op;
        auto result =
            CreateSplitOp(num_splits, dimension_index, location,
                          parent_split_output_value, builder, &child_split_op);
        if (mlir::failed(result)) return mlir::failure();

        new_split_ops.emplace_back(child_split_op);
      }
    }

    std::swap(new_split_ops, split_ops_for_tiled_input);
  }

  // `split_ops_for_tiled_input` now includes the final split nodes
  // from which sharded data will be fed into the TPUExecute ops -- sorted
  // in row-major order.
  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
  for (auto split_op : split_ops_for_tiled_input)
    tiled_inputs->append(split_op.getResults().begin(),
                         split_op.getResults().end());

  return mlir::success();
}

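// Returns true if `sharding` is not supported for values partitioned with
// `tf.TPUPartitionedInput`/`tf.TPUPartitionedOutput` ops; only REPLICATED and
// OTHER (tiled) shardings are supported for such values.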
bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) {
  return sharding != xla::OpSharding::REPLICATED &&
         sharding != xla::OpSharding::OTHER;
}

}  // namespace

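// Extracts the inputs for each logical device computation of `cluster_func`
// into `input_list`, one entry per logical core, as determined by the input
// sharding configuration.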
mlir::LogicalResult ExtractInputsForLogicalDevices(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
  // Initialize the input list for each logical device.
  input_list->reserve(num_cores_per_replica);
  for (int i = 0; i < num_cores_per_replica; ++i)
    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());

  llvm::SmallVector<mlir::Value, 4> cluster_func_inputs(
      cluster_func.getOperands());
  auto sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kInputShardingAttr);
  // If the sharding attribute does not exist, then all inputs are placed on
  // the 0th logical core by default.
  if (!sharding_attrs) {
    (*input_list)[0] = cluster_func_inputs;
    return mlir::success();
  }

  // Enumerate the sharding configuration of each input. If an input has
  // replicated sharding then all logical devices take the value as input. If
  // an input has maximal sharding then only the specified logical device
  // takes the value as input.
  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
    const auto& sharding_attr = sharding_attr_and_index.value();
    const auto input_index = sharding_attr_and_index.index();
    const auto& input_value = cluster_func_inputs[input_index];

    xla::OpSharding sharding;
    sharding.ParseFromString(
        sharding_attr.cast<mlir::StringAttr>().getValue().str());

    const auto input_sharding_type = sharding.type();

    auto tiled_sharding_mismatched = [&](int tiled_input_size) {
      return cluster_func.emitError(
          llvm::formatv("incorrect {0}-th tiled input sharding received. "
                        "Product of tile sharding splits({1}) must be equal to "
                        "number of logical devices : {2}",
                        input_index, tiled_input_size, num_cores_per_replica));
    };

    // If the input is already partitioned using the `tf.TPUPartitionedInput`
    // op, only replicated and tiled (OTHER) shardings are supported. For
    // replicated sharding, the i-th operand of the `tf.TPUPartitionedInput`
    // op is the input to the i-th logical device.
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedInputOp>(
                input_value.getDefiningOp())) {
      if (UnsupportedPartitionedShardingType(input_sharding_type))
        return cluster_func->emitOpError()
               << "unsupported input sharding type "
               << OpSharding_Type_Name(input_sharding_type) << " for "
               << input_index << "-th input";

      if (input_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto& index_and_inputs : llvm::enumerate(*input_list)) {
          index_and_inputs.value().emplace_back(
              partitioned_input.getOperand(index_and_inputs.index()));
        }
      } else {
        assert(input_sharding_type == xla::OpSharding::OTHER);
        if (partitioned_input.inputs().size() != num_cores_per_replica)
          return tiled_sharding_mismatched(partitioned_input.inputs().size());

        for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
          const int assigned_logical_device =
              sharding.tile_assignment_devices(i);
          (*input_list)[assigned_logical_device].emplace_back(
              partitioned_input.inputs()[i]);
        }
      }
      continue;
    }

    if (input_sharding_type == xla::OpSharding::OTHER) {
      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
      auto result = HandleTileShardedInputs(
          cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
      if (mlir::failed(result)) return mlir::failure();

      const int64_t tiled_inputs_size = tiled_inputs.size();
      if (tiled_inputs_size != num_cores_per_replica)
        return tiled_sharding_mismatched(tiled_inputs.size());

      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
        const int assigned_logical_device = sharding.tile_assignment_devices(i);
        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
      }
    } else if (input_sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
    } else {
      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
      const int logical_device_id = sharding.tile_assignment_devices(0);
      (*input_list)[logical_device_id].emplace_back(input_value);
    }
  }
  return mlir::success();
}

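// Parses the output sharding configuration of `cluster_func` into
// `output_sharding_list` and validates each sharding against
// `num_cores_per_replica`.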
mlir::LogicalResult ParseAndValidateOutputSharding(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list) {
  output_sharding_list->reserve(cluster_func.getNumResults());

  const auto output_sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kOutputShardingAttr);
  if (!output_sharding_attrs)
    return cluster_func.emitError(
        "output_sharding_configuration missing from cluster func");

  if (output_sharding_attrs.size() != cluster_func.getNumResults())
    return cluster_func.emitError("incorrect number of output sharding");

  for (auto output_sharding_and_index :
       llvm::enumerate(output_sharding_attrs)) {
    const auto& output_sharding = output_sharding_and_index.value();
    const int sharding_index = output_sharding_and_index.index();
    if (!output_sharding.isa<mlir::StringAttr>())
      return cluster_func.emitError(llvm::formatv(
          "non-string output sharding at index {0}", sharding_index));

    xla::OpSharding sharding;
    if (!sharding.ParseFromString(
            output_sharding.cast<mlir::StringAttr>().getValue().str()))
      return cluster_func.emitError("incorrect sharding format for outputs");

    if (sharding.type() == xla::OpSharding::OTHER &&
        sharding.tile_assignment_devices_size() != num_cores_per_replica)
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Number of "
          "tiled outputs({0}) must match the number of logical "
          "devices({1})",
          sharding.tile_assignment_devices_size(), num_cores_per_replica));

    if (sharding.type() == xla::OpSharding::MAXIMAL &&
        ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) ||
         (sharding.tile_assignment_devices(0) < 0)))
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Maximal "
          "sharding should be assigned to device id in range "
          "[0, {0}). Currently assigned to {1}",
          num_cores_per_replica, sharding.tile_assignment_devices(0)));

    output_sharding_list->emplace_back(std::move(sharding));
  }
  return mlir::success();
}

namespace {

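// Returns true if `sharding` is a maximal sharding that assigns the value to
// logical device `core_id`.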
bool IsAssignedToLogicalDevice(const int core_id,
                               const xla::OpSharding& sharding) {
  return sharding.type() == xla::OpSharding::MAXIMAL &&
         sharding.tile_assignment_devices(0) == core_id;
}

// Returns the index of the return value of the `tf_device.parallel_execute`
// region that represents the cluster func output at index
// |cluster_func_output_index|. Regions of parallel_execute may have different
// numbers of return values depending on the output sharding configuration.
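// Only outputs that are replicated, tiled, or maximally sharded onto `core_id`
// are returned from the region for `core_id`, so earlier cluster func outputs
// that are maximally sharded onto a different core are skipped.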
int MapClusterOutputIndexWithRegionOutputIndex(
    llvm::ArrayRef<xla::OpSharding> output_sharding_config, const int core_id,
    const int cluster_func_output_index) {
  int region_output_index = 0;
  for (int output_index = 0; output_index < cluster_func_output_index;
       ++output_index) {
    const auto& sharding = output_sharding_config[output_index];
    if (sharding.type() != xla::OpSharding::MAXIMAL ||
        IsAssignedToLogicalDevice(core_id, sharding))
      region_output_index++;
  }

  return region_output_index;
}

// Collects the per-logical-device outputs of a tf_device.parallel_execute op
// that together form the tile-sharded cluster func output at
// |cluster_func_output_index|, ordered by the sharding's tile assignment.
llvm::SmallVector<mlir::Value, 4> GetTileShardedOutputsToMerge(
    const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ParallelExecuteOp parallel_execute) {
  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  llvm::SmallVector<mlir::Value, 4> outputs_to_merge;
  outputs_to_merge.reserve(sharding.tile_assignment_devices_size());
  for (const auto logical_device_id : sharding.tile_assignment_devices()) {
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        output_sharding_config, logical_device_id, cluster_func_output_index);
    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    outputs_to_merge.emplace_back(output_from_logical_device);
  }

  return outputs_to_merge;
}

// Merges outputs from the TPU computation for tile-sharded outputs.
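// For example, with tile_assignment_dimensions [2, 2], the four per-device
// outputs are first concatenated pairwise along dimension 1, and the two
// resulting values are then concatenated along dimension 0 to reconstruct the
// full output.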
void HandleTileShardedOutputs(
    const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    const mlir::Location& location, mlir::Value cluster_func_output,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  // Inject concat ops after parallel_execute to merge outputs from
  // concurrently executed computations.
  builder->setInsertionPointAfter(parallel_execute);

  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  auto outputs_to_merge = GetTileShardedOutputsToMerge(
      cluster_func_output_index, output_sharding_config, parallel_execute);

  // Creates a tree of Concat ops that merges outputs from multiple logical
  // devices to a single replica output.
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  int concat_dimension = sharding.tile_assignment_dimensions_size() - 1;
  for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) {
    if (num_splits == 1) {
      --concat_dimension;
      continue;
    }

    llvm::SmallVector<mlir::Value, 4> new_outputs;
    new_outputs.reserve(num_splits);
    for (int i = 0, end = outputs_to_merge.size(); i < end;
         i = i + num_splits) {
      mlir::TF::ConcatOp concat_op =
          CreateConcatOp(concat_dimension, location,
                         llvm::ArrayRef<mlir::Value>{
                             outputs_to_merge.begin() + i,
                             outputs_to_merge.begin() + i + num_splits},
                         builder);
      new_outputs.emplace_back(concat_op.getResult());
    }

    std::swap(new_outputs, outputs_to_merge);
    --concat_dimension;
  }

  assert(outputs_to_merge.size() == 1);
  cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]);
}

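// Computes the per-logical-device (tiled) output type of a tile-sharded
// cluster func output, validating that every sharded dimension is evenly
// divisible by its number of splits.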
mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape(
    const mlir::Location& location,
    const mlir::TensorType cluster_func_output_type,
    const xla::OpSharding& output_sharding,
    mlir::Type* tiled_logical_computation_type) {
  auto new_output_shape =
      llvm::to_vector<4>(cluster_func_output_type.getShape());
  for (auto dimension_and_output_splits :
       llvm::enumerate(output_sharding.tile_assignment_dimensions())) {
    const auto dimension_index = dimension_and_output_splits.index();
    const auto output_splits = dimension_and_output_splits.value();
    const auto output_shape = cluster_func_output_type.getShape();

    if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) {
      *tiled_logical_computation_type = cluster_func_output_type;
      break;
    }

    auto output_shape_at_dim =
        cluster_func_output_type.getShape()[dimension_index];
    if (output_shape_at_dim % output_splits != 0) {
      // Emit the error and return failure so callers do not use an invalid
      // tiled type.
      mlir::emitError(
          location,
          llvm::formatv("incorrect output sharding received. "
                        "{0}-th dimension of the output must be "
                        "evenly divisible by {1}, got dimension "
                        "shape {2}",
                        dimension_index, output_splits, output_shape_at_dim));
      return mlir::failure();
    }

    new_output_shape[dimension_index] =
        output_shape[dimension_index] / output_splits;
  }

  *tiled_logical_computation_type = mlir::RankedTensorType::get(
      new_output_shape, cluster_func_output_type.getElementType());

  return mlir::success();
}

}  // namespace

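// Returns the output types of the computation assigned to logical core
// `core_id`: tiled types for tile-sharded outputs, and the original types for
// outputs that are replicated or maximally sharded onto `core_id`.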
mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    llvm::SmallVectorImpl<mlir::Type>* output_types) {
  output_types->reserve(cluster_func.getNumResults());

  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();
    const auto cluster_func_output_type =
        result_and_index.value().getType().cast<mlir::TensorType>();

    // If the output shape of the cluster func is statically known and the
    // output is tile-sharded, then each sharded dimension of the output shape
    // must be evenly divisible by the number of shardings along that
    // dimension.
    if (output_sharding_type == xla::OpSharding::OTHER) {
      mlir::Type tiled_logical_computation_type;
      if (cluster_func_output_type.hasRank()) {
        auto result = ValidateAndGetTiledExecuteOutputShape(
            cluster_func.getLoc(), cluster_func_output_type, output_sharding,
            &tiled_logical_computation_type);
        if (mlir::failed(result)) return mlir::failure();
      } else {
        tiled_logical_computation_type = cluster_func_output_type;
      }
      output_types->emplace_back(tiled_logical_computation_type);
    } else if (output_sharding_type == xla::OpSharding::REPLICATED ||
               IsAssignedToLogicalDevice(core_id, output_sharding)) {
      output_types->emplace_back(cluster_func_output_type);
    }
  }

  return mlir::success();
}

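// Remaps the results of the `tf_device.parallel_execute` op to the users of
// the corresponding `tf_device.cluster_func` results, merging tile-sharded
// outputs with concat ops where needed.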
mlir::LogicalResult RemapOutputsFromLogicalDevices(
    const mlir::Location& location,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto cluster_func_output = result_and_index.value();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();

    // If the output is demultiplexed using the `tf.TPUPartitionedOutput` op,
    // only replicated and tiled (OTHER) shardings are supported. For
    // replicated sharding, the i-th output of the `tf.TPUPartitionedOutput`
    // op maps to the output of the i-th logical device. Also, the
    // `tf.TPUPartitionedOutput` op must be the unique user of the
    // `tf_device.cluster_func` output.
    mlir::TF::TPUPartitionedOutputOp partitioned_output;
    for (auto user : cluster_func_output.getUsers()) {
      if (auto partitioned_output_user =
              llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedOutputOp>(user)) {
        partitioned_output = partitioned_output_user;
        break;
      }
    }
    if (partitioned_output) {
      if (!cluster_func_output.hasOneUse())
        return partitioned_output.emitOpError()
               << "must be a unique user of tf_device.cluster_func output "
               << *cluster_func_output.getOwner();
      if (UnsupportedPartitionedShardingType(output_sharding_type))
        return cluster_func.emitOpError()
               << "unsupported output sharding type "
               << OpSharding_Type_Name(output_sharding_type) << " for "
               << output_index << "-th output";

      if (output_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto index_and_output :
             llvm::enumerate(partitioned_output.output())) {
          const auto output_from_logical_device =
              parallel_execute.GetRegionOutputs(
                  index_and_output.index())[output_index];
          index_and_output.value().replaceAllUsesWith(
              output_from_logical_device);
        }
      } else {
        assert(output_sharding_type == xla::OpSharding::OTHER);
        auto tile_sharded_outputs = GetTileShardedOutputsToMerge(
            output_index, output_sharding_config, parallel_execute);
        for (auto result :
             llvm::zip(partitioned_output.output(), tile_sharded_outputs))
          std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }
      continue;
    }

    if (output_sharding_type == xla::OpSharding::OTHER) {
      HandleTileShardedOutputs(output_index, output_sharding_config, location,
                               cluster_func_output, parallel_execute, builder);
      continue;
    }

    int logical_device_id = 0;
    if (output_sharding_type == xla::OpSharding::MAXIMAL)
      logical_device_id = output_sharding.tile_assignment_devices(0);

    // For replicated and maximal sharding configurations, remap the output
    // from the corresponding parallel_execute region to the users of the
    // cluster func output.
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        output_sharding_config, logical_device_id, output_index);

    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    cluster_func_output.replaceAllUsesWith(output_from_logical_device);
  }
  return mlir::success();
}

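// Returns, for each logical core, the indices of the metadata arguments
// assigned to that core, as determined by each argument's sharding.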
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> GetMetadataArgumentMapping(
    const tpu::TPUCompileMetadataProto& metadata) {
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings(
      metadata.num_cores_per_replica(), llvm::SmallVector<int64_t, 4>());

  if (metadata.num_cores_per_replica() == 1) {
    input_mappings.front().resize(metadata.args_size());
    std::iota(input_mappings.front().begin(), input_mappings.front().end(), 0);
    return input_mappings;
  }

  for (const auto& arg_and_idx : llvm::enumerate(metadata.args())) {
    const auto& sharding = arg_and_idx.value().sharding();
    const int64_t idx = arg_and_idx.index();

    const auto sharding_type = sharding.type();
    if (sharding_type == xla::OpSharding::OTHER) {
      for (const auto& device : sharding.tile_assignment_devices())
        input_mappings[device].push_back(idx);
    } else if (sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& input : input_mappings) input.push_back(idx);
    } else {
      assert(sharding_type == xla::OpSharding::MAXIMAL);
      input_mappings[sharding.tile_assignment_devices(0)].push_back(idx);
    }
  }

  return input_mappings;
}

}  // namespace tensorflow