/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"

#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace tensorflow {

const char* const kInputShardingAttr = "input_sharding_configuration";
const char* const kOutputShardingAttr = "output_sharding_configuration";

namespace {

constexpr char kNumSplitAttr[] = "num_split";

// Creates a tf::SplitOp that splits 'src_input' 'num_split' ways
// along the 'split_dimension' dimension and returns the split values.
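// For example, splitting an [8, 16] input 4 ways along dimension 0 produces
// four [2, 16] outputs; the split dimension must be evenly divisible by
// 'num_split'.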
mlir::LogicalResult CreateSplitOp(const int num_split,
                                  const int split_dimension,
                                  const mlir::Location& location,
                                  mlir::Value src_input,
                                  mlir::OpBuilder* builder,
                                  mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly set the output shape of the split op if the input shape is
  // statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return mlir::emitError(
            location,
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split));
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  (*split_op)->setAttr(
      kNumSplitAttr,
      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
  return mlir::success();
}

// Creates a tf::ConcatOp that merges the `inputs` values along
// `concat_dimension`.
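// For example, concatenating four [2, 16] values along dimension 0 produces a
// single [8, 16] result.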
mlir::LogicalResult CreateConcatOp(const int concat_dimension,
                                   const mlir::Location& location,
                                   mlir::ArrayRef<mlir::Value> inputs,
                                   mlir::OpBuilder* builder,
                                   mlir::TF::ConcatOp* concat_op) {
  // Creates a const op to hold concat dimension value.
  auto concat_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto concat_dimension_attr =
      mlir::DenseElementsAttr::get(concat_dim_type, concat_dimension);
  auto concat_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, concat_dim_type, concat_dimension_attr);

  // Correctly set the output shape of the concat op if the output shape is
  // statically known. Since the output shape of the TPUExecute op must be the
  // same across logical devices, we refer to the shape of the 0th logical
  // device's computation output.
  mlir::Type output_type;
  auto input_type = inputs[0].getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[concat_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      shape[concat_dimension] = shape[concat_dimension] * inputs.size();
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  *concat_op = builder->create<mlir::TF::ConcatOp>(
      location, output_type, concat_dimension_op.output(), inputs);
  return mlir::success();
}

// For tile-sharded inputs to a TPU computation, injects split ops between the
// input values and the TPU computation so that tiled input values are passed
// in as inputs to the TPU computations. If more than one dimension is sharded,
// a tree of connected split ops is added before the tf_device.parallel_execute
// op.
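// For example, an input with tile_assignment_dimensions [2, 2] gets a root
// split along dimension 0, and each of its two results is split again along
// dimension 1, yielding four tiled values in row-major order.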
mlir::LogicalResult HandleTileShardedInputs(
    const mlir::Location& location, const xla::OpSharding& input_sharding,
    const mlir::Value& original_source, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
  split_ops_for_tiled_input.reserve(
      input_sharding.tile_assignment_devices_size());

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that input data is sharded in row-major order.
  // Split nodes at the i-th depth from the original input node represent
  // nodes that split the input data along the i-th dimension.
  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
    const int num_splits = num_splits_and_index.value();
    const int dimension_index = num_splits_and_index.index();
    if (num_splits == 1) continue;

    // Creates root split op.
    if (split_ops_for_tiled_input.empty()) {
      mlir::TF::SplitOp root_split_op;
      auto result = CreateSplitOp(num_splits, dimension_index, location,
                                  original_source, builder, &root_split_op);
      if (mlir::failed(result)) return mlir::failure();

      split_ops_for_tiled_input.emplace_back(root_split_op);
      continue;
    }

    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);

    for (auto split_op : split_ops_for_tiled_input) {
      for (auto parent_split_output_value : split_op.getResults()) {
        mlir::TF::SplitOp child_split_op;
        auto result =
            CreateSplitOp(num_splits, dimension_index, location,
                          parent_split_output_value, builder, &child_split_op);
        if (mlir::failed(result)) return mlir::failure();

        new_split_ops.emplace_back(child_split_op);
      }
    }

    std::swap(new_split_ops, split_ops_for_tiled_input);
  }

  // `split_ops_for_tiled_input` now includes the final split nodes
  // from which sharded data will be fed into TPUExecute ops -- sorted in
  // row-major order.
  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
  for (auto split_op : split_ops_for_tiled_input)
    tiled_inputs->append(split_op.getResults().begin(),
                         split_op.getResults().end());

  return mlir::success();
}

bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) {
  return sharding != xla::OpSharding::REPLICATED &&
         sharding != xla::OpSharding::OTHER;
}

}  // namespace

mlir::LogicalResult ExtractInputsForLogicalDevices(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
  // Initialize the input list for each logical device.
  input_list->reserve(num_cores_per_replica);
  for (int i = 0; i < num_cores_per_replica; ++i)
    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());

  llvm::SmallVector<mlir::Value, 4> cluster_func_inputs(
      cluster_func.getOperands());
  auto sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kInputShardingAttr);
  // If sharding attribute does not exist, then all inputs are placed on 0th
  // logical core by default.
  if (!sharding_attrs) {
    (*input_list)[0] = cluster_func_inputs;
    return mlir::success();
  }

  // Enumerate the sharding configuration of each input. If an input has
  // replicated sharding, then all logical devices take the value as input. If
  // an input has maximal sharding, then only the specified logical device
  // takes the value as input.
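  // For example, with num_cores_per_replica = 2, a replicated input is
  // appended to both (*input_list)[0] and (*input_list)[1], while a maximal
  // input assigned to logical device 1 is appended only to (*input_list)[1].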
  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
    const auto& sharding_attr = sharding_attr_and_index.value();
    const auto input_index = sharding_attr_and_index.index();
    const auto& input_value = cluster_func_inputs[input_index];

    xla::OpSharding sharding;
    sharding.ParseFromString(
        sharding_attr.cast<mlir::StringAttr>().getValue().str());

    const auto input_sharding_type = sharding.type();

    auto tiled_sharding_mismatched = [&](int tiled_input_size) {
      return cluster_func.emitError(
          llvm::formatv("incorrect {0}-th tiled input sharding received. "
                        "Product of tile sharding splits({1}) must be equal to "
                        "number of logical devices : {2}",
                        input_index, tiled_input_size, num_cores_per_replica));
    };

    // If input is already partitioned using the `tf.TPUPartitionedInput` op,
    // only replicated sharding is supported where i-th operand to
    // `tf.TPUPartitionedInput` op is input to the i-th logical device.
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedInputOp>(
                input_value.getDefiningOp())) {
      if (UnsupportedPartitionedShardingType(input_sharding_type))
        return cluster_func->emitOpError()
               << "unsupported input sharding type "
               << OpSharding_Type_Name(input_sharding_type) << " for "
               << input_index << "-th input";

      if (input_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto& index_and_inputs : llvm::enumerate(*input_list)) {
          index_and_inputs.value().emplace_back(
              partitioned_input.getOperand(index_and_inputs.index()));
        }
      } else {
        assert(input_sharding_type == xla::OpSharding::OTHER);
        if (partitioned_input.inputs().size() != num_cores_per_replica)
          return tiled_sharding_mismatched(partitioned_input.inputs().size());

        for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
          const int assigned_logical_device =
              sharding.tile_assignment_devices(i);
          (*input_list)[assigned_logical_device].emplace_back(
              partitioned_input.inputs()[i]);
        }
      }
      continue;
    }

    if (input_sharding_type == xla::OpSharding::OTHER) {
      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
      auto result = HandleTileShardedInputs(
          cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
      if (mlir::failed(result)) return mlir::failure();

      const int64 tiled_inputs_size = tiled_inputs.size();
      if (tiled_inputs_size != num_cores_per_replica)
        return tiled_sharding_mismatched(tiled_inputs.size());

      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
        const int assigned_logical_device = sharding.tile_assignment_devices(i);
        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
      }
    } else if (input_sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
    } else {
      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
      const int logical_device_id = sharding.tile_assignment_devices(0);
      (*input_list)[logical_device_id].emplace_back(input_value);
    }
  }
  return mlir::success();
}

mlir::LogicalResult ParseAndValidateOutputSharding(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list) {
  output_sharding_list->reserve(cluster_func.getNumResults());

  const auto output_sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kOutputShardingAttr);
  if (!output_sharding_attrs)
    return cluster_func.emitError(
        "output_sharding_configuration missing from cluster func");

  if (output_sharding_attrs.size() != cluster_func.getNumResults())
    return cluster_func.emitError("incorrect number of output sharding");

  for (auto output_sharding_and_index :
       llvm::enumerate(output_sharding_attrs)) {
    const auto& output_sharding = output_sharding_and_index.value();
    const int sharding_index = output_sharding_and_index.index();
    if (!output_sharding.isa<mlir::StringAttr>())
      return cluster_func.emitError(llvm::formatv(
          "non-string output sharding at index {0}", sharding_index));

    xla::OpSharding sharding;
    if (!sharding.ParseFromString(
            output_sharding.cast<mlir::StringAttr>().getValue().str()))
      return cluster_func.emitError("incorrect sharding format for outputs");

    if (sharding.type() == xla::OpSharding::OTHER &&
        sharding.tile_assignment_devices_size() != num_cores_per_replica)
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Number of "
          "tiled outputs({0}) must match the number of logical "
          "devices({1})",
          sharding.tile_assignment_devices_size(), num_cores_per_replica));

    if (sharding.type() == xla::OpSharding::MAXIMAL &&
        ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) ||
         (sharding.tile_assignment_devices(0) < 0)))
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Maximal "
          "sharding should be assigned to device id in range "
          "[0, {0}). Currently assigned to {1}",
          num_cores_per_replica, sharding.tile_assignment_devices(0)));

    output_sharding_list->emplace_back(std::move(sharding));
  }
  return mlir::success();
}

namespace {

bool IsAssignedToLogicalDevice(const int core_id,
                               const xla::OpSharding& sharding) {
  return sharding.type() == xla::OpSharding::MAXIMAL &&
         sharding.tile_assignment_devices(0) == core_id;
}

// Returns the index of the return value in the `tf_device.parallel_execute`
// region that represents the cluster func output at index
// |cluster_func_output_index|. Regions of parallel_execute may have different
// numbers of return values depending on the output sharding configuration.
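// For example, if output 0 has maximal sharding assigned to logical device 1
// and output 1 is replicated, then for core_id 0 the cluster func output at
// index 2 maps to region output index 1 (output 0 is skipped), while for
// core_id 1 it maps to region output index 2.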
int MapClusterOutputIndexWithRegionOutputIndex(
    llvm::ArrayRef<xla::OpSharding> output_sharding_config, const int core_id,
    const int cluster_func_output_index) {
  int region_output_index = 0;
  for (int output_index = 0; output_index < cluster_func_output_index;
       ++output_index) {
    const auto& sharding = output_sharding_config[output_index];
    if (sharding.type() != xla::OpSharding::MAXIMAL ||
        IsAssignedToLogicalDevice(core_id, sharding))
      region_output_index++;
  }

  return region_output_index;
}

// Collects tile sharded outputs from a tf_device.parallel_execute to remap from
// the TPU computation result.
llvm::SmallVector<mlir::Value, 4> GetTileShardedOutputsToMerge(
    const int cluster_func_output_index, const xla::OpSharding& sharding,
    mlir::tf_device::ParallelExecuteOp parallel_execute) {
  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  llvm::SmallVector<mlir::Value, 4> outputs_to_merge;
  outputs_to_merge.reserve(sharding.tile_assignment_devices_size());
  for (const auto logical_device_id : sharding.tile_assignment_devices()) {
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        sharding, logical_device_id, cluster_func_output_index);
    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    outputs_to_merge.emplace_back(output_from_logical_device);
  }

  return outputs_to_merge;
}

// Merges outputs from TPU computation for tile-sharded outputs.
mlir::LogicalResult HandleTileShardedOutputs(
    const int cluster_func_output_index, const xla::OpSharding& sharding,
    const mlir::Location& location, mlir::Value cluster_func_output,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  // Inject concat ops after parallel_execute to merge outputs from
  // concurrently executed computations.
  builder->setInsertionPointAfter(parallel_execute);

  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  auto outputs_to_merge = GetTileShardedOutputsToMerge(
      cluster_func_output_index, sharding, parallel_execute);

  // Creates a tree of Concat ops that merges outputs from multiple logical
  // devices to a single replica output.
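  // For example, with tile_assignment_dimensions [2, 2] the four per-device
  // outputs are first concatenated in pairs along dimension 1, and the two
  // results are then concatenated along dimension 0 into the full output.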
  int concat_dimension = sharding.tile_assignment_dimensions_size() - 1;
  for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) {
    if (num_splits == 1) {
      --concat_dimension;
      continue;
    }

    llvm::SmallVector<mlir::Value, 4> new_outputs;
    new_outputs.reserve(num_splits);
    for (int i = 0, end = outputs_to_merge.size(); i < end;
         i = i + num_splits) {
      mlir::TF::ConcatOp concat_op;
      auto result =
          CreateConcatOp(concat_dimension, location,
                         llvm::ArrayRef<mlir::Value>{
                             outputs_to_merge.begin() + i,
                             outputs_to_merge.begin() + i + num_splits},
                         builder, &concat_op);
      if (mlir::failed(result)) return mlir::failure();

      new_outputs.emplace_back(concat_op.getResult());
    }

    std::swap(new_outputs, outputs_to_merge);
    --concat_dimension;
  }

  assert(outputs_to_merge.size() == 1);
  cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]);
  return mlir::success();
}

mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape(
    const mlir::Location& location,
    const mlir::TensorType cluster_func_output_type,
    const xla::OpSharding& output_sharding,
    mlir::Type* tiled_logical_computation_type) {
  auto new_output_shape =
      llvm::to_vector<4>(cluster_func_output_type.getShape());
  for (auto dimension_and_output_splits :
       llvm::enumerate(output_sharding.tile_assignment_dimensions())) {
    const auto dimension_index = dimension_and_output_splits.index();
    const auto output_splits = dimension_and_output_splits.value();
    const auto output_shape = cluster_func_output_type.getShape();

    if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) {
      *tiled_logical_computation_type = cluster_func_output_type;
      break;
    }

    auto output_shape_at_dim =
        cluster_func_output_type.getShape()[dimension_index];
    if (output_shape_at_dim % output_splits != 0) {
      mlir::emitError(
          location,
          llvm::formatv("incorrect output sharding received. "
                        "{0}-th dimension of the output must be "
                        "evenly divisible by {1}, got dimension "
                        "shape {2}",
                        dimension_index, output_splits, output_shape_at_dim));
    }

    new_output_shape[dimension_index] =
        output_shape[dimension_index] / output_splits;
  }

  *tiled_logical_computation_type = mlir::RankedTensorType::get(
      new_output_shape, cluster_func_output_type.getElementType());

  return mlir::success();
}

}  // namespace

mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    llvm::SmallVectorImpl<mlir::Type>* output_types) {
  output_types->reserve(cluster_func.getNumResults());

  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();
    const auto cluster_func_output_type =
        result_and_index.value().getType().cast<mlir::TensorType>();

    // If the output shape of the cluster func is statically known and the
    // output is tile-sharded, then the corresponding output shape of the
    // cluster func must be evenly divisible by the number of shardings.
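    // For example, a statically-shaped [8, 16] output with
    // tile_assignment_dimensions [2, 2] yields a per-logical-device output
    // type of [4, 8].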
    if (output_sharding_type == xla::OpSharding::OTHER) {
      mlir::Type tiled_logical_computation_type;
      if (cluster_func_output_type.hasRank()) {
        auto result = ValidateAndGetTiledExecuteOutputShape(
            cluster_func.getLoc(), cluster_func_output_type, output_sharding,
            &tiled_logical_computation_type);
        if (mlir::failed(result)) return mlir::failure();
      } else {
        tiled_logical_computation_type = cluster_func_output_type;
      }
      output_types->emplace_back(tiled_logical_computation_type);
    } else if (output_sharding_type == xla::OpSharding::REPLICATED ||
               IsAssignedToLogicalDevice(core_id, output_sharding)) {
      output_types->emplace_back(cluster_func_output_type);
    }
  }

  return mlir::success();
}

mlir::LogicalResult RemapOutputsFromLogicalDevices(
    const mlir::Location& location,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto cluster_func_output = result_and_index.value();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();

    // If output is demultiplexed using the `tf.TPUPartitionedOutput` op, only
    // replicated sharding is supported where i-th output of
    // `tf.TPUPartitionedOutput` op maps to the output of i-th logical device.
    // Also `tf.TPUPartitionedOutput` op must be a unique user of
    // `tf_device.cluster_func` output.
    mlir::TF::TPUPartitionedOutputOp partitioned_output;
    for (auto user : cluster_func_output.getUsers()) {
      if (auto partitioned_output_user =
              llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedOutputOp>(user)) {
        partitioned_output = partitioned_output_user;
        break;
      }
    }
    if (partitioned_output) {
      if (!cluster_func_output.hasOneUse())
        return partitioned_output.emitOpError()
               << "must be a unique user of tf_device.cluster_func output "
               << *cluster_func_output.getOwner();
      if (UnsupportedPartitionedShardingType(output_sharding_type))
        return cluster_func.emitOpError()
               << "unsupported output sharding type "
               << OpSharding_Type_Name(output_sharding_type) << " for "
               << output_index << "-th output";

      if (output_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto index_and_output :
             llvm::enumerate(partitioned_output.output())) {
          const auto output_from_logical_device =
              parallel_execute.GetRegionOutputs(
                  index_and_output.index())[output_index];
          index_and_output.value().replaceAllUsesWith(
              output_from_logical_device);
        }
      } else {
        assert(output_sharding_type == xla::OpSharding::OTHER);
        auto tile_sharded_outputs = GetTileShardedOutputsToMerge(
            output_index, output_sharding, parallel_execute);
        for (auto result :
             llvm::zip(partitioned_output.output(), tile_sharded_outputs))
          std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }
      continue;
    }

    if (output_sharding_type == xla::OpSharding::OTHER) {
      (void)HandleTileShardedOutputs(output_index, output_sharding, location,
                                     cluster_func_output, parallel_execute,
                                     builder);
      continue;
    }

    int logical_device_id = 0;
    if (output_sharding_type == xla::OpSharding::MAXIMAL)
      logical_device_id = output_sharding.tile_assignment_devices(0);

    // For maximal sharding configuration, correctly remap outputs from
    // parallel_execute region to users of the cluster func.
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        output_sharding_config, logical_device_id, output_index);

    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    cluster_func_output.replaceAllUsesWith(output_from_logical_device);
  }
  return mlir::success();
}

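// Returns, for each logical core, the list of metadata argument indices fed
// to that core: tiled (OTHER) arguments go to each device in their tile
// assignment, replicated arguments go to every core, and maximal arguments go
// only to their assigned core. When num_cores_per_replica is 1, every
// argument maps to core 0.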
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> GetMetadataArgumentMapping(
    const tpu::TPUCompileMetadataProto& metadata) {
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings(
      metadata.num_cores_per_replica(), llvm::SmallVector<int64_t, 4>());

  if (metadata.num_cores_per_replica() == 1) {
    input_mappings.front().resize(metadata.args_size());
    std::iota(input_mappings.front().begin(), input_mappings.front().end(), 0);
    return input_mappings;
  }

  for (const auto& arg_and_idx : llvm::enumerate(metadata.args())) {
    const auto& sharding = arg_and_idx.value().sharding();
    const int64_t idx = arg_and_idx.index();

    const auto sharding_type = sharding.type();
    if (sharding_type == xla::OpSharding::OTHER) {
      for (const auto& device : sharding.tile_assignment_devices())
        input_mappings[device].push_back(idx);
    } else if (sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& input : input_mappings) input.push_back(idx);
    } else {
      assert(sharding_type == xla::OpSharding::MAXIMAL);
      input_mappings[sharding.tile_assignment_devices(0)].push_back(idx);
    }
  }

  return input_mappings;
}

}  // namespace tensorflow