/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"

#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace tensorflow {
namespace {

constexpr char kNumSplitAttr[] = "num_split";

// Creates a tf::SplitOp that splits 'src_input' into 'num_split' slices along
// the 'split_dimension' dimension and returns the split values.
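// For example (illustrative, not from the original source): splitting a
// tensor<4x8xf32> value along dimension 1 with num_split = 2 yields two
// tensor<4x4xf32> results.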
mlir::LogicalResult CreateSplitOp(const int num_split,
                                  const int split_dimension,
                                  const mlir::Location& location,
                                  mlir::Value src_input,
                                  mlir::OpBuilder* builder,
                                  mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Sets the output shape of the split op if the input shape is statically
  // known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return mlir::emitError(
            location,
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split));
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  (*split_op)->setAttr(
      kNumSplitAttr,
      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
  return mlir::success();
}

// Creates a tf::ConcatOp that merges `inputs` values along `concat_dimension`.
mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension,
                                  const mlir::Location& location,
                                  mlir::ArrayRef<mlir::Value> inputs,
                                  mlir::OpBuilder* builder) {
  // Creates a const op to hold concat dimension value.
  auto concat_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto concat_dimension_attr =
      mlir::DenseElementsAttr::get(concat_dim_type, concat_dimension);
  auto concat_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, concat_dim_type, concat_dimension_attr);

  // Sets the output shape of the concat op if the input shape is statically
  // known. Since the output shape of the TPUExecute op must be the same
  // across logical devices, we use the shape of the 0th logical device's
  // computation output.
  mlir::Type output_type;
  auto input_type = inputs[0].getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[concat_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      shape[concat_dimension] = shape[concat_dimension] * inputs.size();
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  return builder->create<mlir::TF::ConcatOp>(
      location, output_type, concat_dimension_op.output(), inputs);
}

// For tile-sharded inputs to a TPU computation, injects split ops between the
// input values and the TPU computation so that tiled input values are passed
// in as inputs to the TPU computation. If more than one dimension is sharded,
// a tree of connected split ops is added before the tf_device.parallel_execute
// op.
mlir::LogicalResult HandleTileShardedInputs(
    const mlir::Location& location, const xla::OpSharding& input_sharding,
    const mlir::Value& original_source, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
  split_ops_for_tiled_input.reserve(
      input_sharding.tile_assignment_devices_size());

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes are
  // created such that input data is sharded in row-major order.
  // Split nodes at the i-th depth from the original input node represent nodes
  // that split the input data along the i-th dimension.
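  // For example (illustrative, not from the original source): with
  // tile_assignment_dimensions = [2, 2], the root split divides the input
  // along dimension 0 into two values, each of which is then split along
  // dimension 1, yielding four tiles ordered [0,0], [0,1], [1,0], [1,1].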
  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
    const int num_splits = num_splits_and_index.value();
    const int dimension_index = num_splits_and_index.index();
    if (num_splits == 1) continue;

    // Creates root split op.
    if (split_ops_for_tiled_input.empty()) {
      mlir::TF::SplitOp root_split_op;
      auto result = CreateSplitOp(num_splits, dimension_index, location,
                                  original_source, builder, &root_split_op);
      if (mlir::failed(result)) return mlir::failure();

      split_ops_for_tiled_input.emplace_back(root_split_op);
      continue;
    }

    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);

    for (auto split_op : split_ops_for_tiled_input) {
      for (auto parent_split_output_value : split_op.getResults()) {
        mlir::TF::SplitOp child_split_op;
        auto result =
            CreateSplitOp(num_splits, dimension_index, location,
                          parent_split_output_value, builder, &child_split_op);
        if (mlir::failed(result)) return mlir::failure();

        new_split_ops.emplace_back(child_split_op);
      }
    }

    std::swap(new_split_ops, split_ops_for_tiled_input);
  }

  // `split_ops_for_tiled_input` now includes the final split nodes
  // from which sharded data will be fed into TPUExecute ops -- sorted in
  // row-major order.
  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
  for (auto split_op : split_ops_for_tiled_input)
    tiled_inputs->append(split_op.getResults().begin(),
                         split_op.getResults().end());

  return mlir::success();
}

bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) {
  return sharding != xla::OpSharding::REPLICATED &&
         sharding != xla::OpSharding::OTHER;
}

}  // namespace

mlir::LogicalResult ExtractInputsForLogicalDevices(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
  // Initializes the input list for each logical device.
  input_list->reserve(num_cores_per_replica);
  for (int i = 0; i < num_cores_per_replica; ++i)
    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());

  llvm::SmallVector<mlir::Value, 4> cluster_func_inputs(
      cluster_func.getOperands());
  auto sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kInputShardingAttr);
  // If the sharding attribute does not exist, then all inputs are placed on
  // the 0th logical core by default.
  if (!sharding_attrs) {
    (*input_list)[0] = cluster_func_inputs;
    return mlir::success();
  }

  // Enumerates the sharding configuration for each input. If an input has
  // replicated sharding, then all logical devices take the value as input. If
  // an input has maximal sharding, then only the specified logical device
  // takes the value as input.
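  // For example (illustrative, not from the original source): with
  // num_cores_per_replica = 2, a replicated input is appended to both
  // (*input_list)[0] and (*input_list)[1], while a maximal input assigned to
  // logical device 1 is appended only to (*input_list)[1].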
  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
    const auto& sharding_attr = sharding_attr_and_index.value();
    const auto input_index = sharding_attr_and_index.index();
    const auto& input_value = cluster_func_inputs[input_index];

    xla::OpSharding sharding;
    sharding.ParseFromString(
        sharding_attr.cast<mlir::StringAttr>().getValue().str());

    const auto input_sharding_type = sharding.type();

    auto tiled_sharding_mismatched = [&](int tiled_input_size) {
      return cluster_func.emitError(
          llvm::formatv("incorrect {0}-th tiled input sharding received. "
                        "Product of tile sharding splits({1}) must be equal to "
                        "number of logical devices : {2}",
                        input_index, tiled_input_size, num_cores_per_replica));
    };

    // If the input is already partitioned using the `tf.TPUPartitionedInput`
    // op, only replicated sharding is supported, where the i-th operand to the
    // `tf.TPUPartitionedInput` op is the input to the i-th logical device.
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedInputOp>(
                input_value.getDefiningOp())) {
      if (UnsupportedPartitionedShardingType(input_sharding_type))
        return cluster_func->emitOpError()
               << "unsupported input sharding type "
               << OpSharding_Type_Name(input_sharding_type) << " for "
               << input_index << "-th input";

      if (input_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto& index_and_inputs : llvm::enumerate(*input_list)) {
          index_and_inputs.value().emplace_back(
              partitioned_input.getOperand(index_and_inputs.index()));
        }
      } else {
        assert(input_sharding_type == xla::OpSharding::OTHER);
        if (partitioned_input.inputs().size() != num_cores_per_replica)
          return tiled_sharding_mismatched(partitioned_input.inputs().size());

        for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
          const int assigned_logical_device =
              sharding.tile_assignment_devices(i);
          (*input_list)[assigned_logical_device].emplace_back(
              partitioned_input.inputs()[i]);
        }
      }
      continue;
    }

    if (input_sharding_type == xla::OpSharding::OTHER) {
      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
      auto result = HandleTileShardedInputs(
          cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
      if (mlir::failed(result)) return mlir::failure();

      const int64_t tiled_inputs_size = tiled_inputs.size();
      if (tiled_inputs_size != num_cores_per_replica)
        return tiled_sharding_mismatched(tiled_inputs.size());

      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
        const int assigned_logical_device = sharding.tile_assignment_devices(i);
        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
      }
    } else if (input_sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
    } else {
      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
      const int logical_device_id = sharding.tile_assignment_devices(0);
      (*input_list)[logical_device_id].emplace_back(input_value);
    }
  }
  return mlir::success();
}

mlir::LogicalResult ParseAndValidateOutputSharding(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list) {
  output_sharding_list->reserve(cluster_func.getNumResults());

  const auto output_sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kOutputShardingAttr);
  if (!output_sharding_attrs)
    return cluster_func.emitError(
        "output_sharding_configuration missing from cluster func");

  if (output_sharding_attrs.size() != cluster_func.getNumResults())
    return cluster_func.emitError("incorrect number of output sharding");

  for (auto output_sharding_and_index :
       llvm::enumerate(output_sharding_attrs)) {
    const auto& output_sharding = output_sharding_and_index.value();
    const int sharding_index = output_sharding_and_index.index();
    if (!output_sharding.isa<mlir::StringAttr>())
      return cluster_func.emitError(llvm::formatv(
          "non-string output sharding at index {0}", sharding_index));

    xla::OpSharding sharding;
    if (!sharding.ParseFromString(
            output_sharding.cast<mlir::StringAttr>().getValue().str()))
      return cluster_func.emitError("incorrect sharding format for outputs");

    if (sharding.type() == xla::OpSharding::OTHER &&
        sharding.tile_assignment_devices_size() != num_cores_per_replica)
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Number of "
          "tiled outputs({0}) must match the number of logical "
          "devices({1})",
          sharding.tile_assignment_devices_size(), num_cores_per_replica));

    if (sharding.type() == xla::OpSharding::MAXIMAL &&
        ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) ||
         (sharding.tile_assignment_devices(0) < 0)))
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Maximal "
          "sharding should be assigned to device id in range "
          "[0, {0}). Currently assigned to {1}",
          num_cores_per_replica, sharding.tile_assignment_devices(0)));

    output_sharding_list->emplace_back(std::move(sharding));
  }
  return mlir::success();
}

namespace {

bool IsAssignedToLogicalDevice(const int core_id,
                               const xla::OpSharding& sharding) {
  return sharding.type() == xla::OpSharding::MAXIMAL &&
         sharding.tile_assignment_devices(0) == core_id;
}

// Returns the index of the return value of the `tf_device.parallel_execute`
// region that represents the cluster func output at index
// |cluster_func_output_index|. Regions of parallel_execute may have different
// return values depending on the output sharding configuration.
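// For example (illustrative, not from the original source): with two cores and
// three cluster func outputs, where output 0 is replicated, output 1 is
// assigned to core 1, and output 2 is assigned to core 0,
// `cluster_to_core_index` would be {{0, -1, 1}, {0, 1, -1}}, and looking up
// output 1 on core 0 reports an error.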
mlir::LogicalResult LookupClusterToCoreIndex(
    const mlir::Location& location,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    const int core_id, const int cluster_func_output_index,
    int* core_output_index) {
  *core_output_index =
      cluster_to_core_index[core_id][cluster_func_output_index];
  if (*core_output_index == -1) {
    mlir::emitError(
        location,
        llvm::formatv("Attempted to map cluster_func output index {0} to "
                      "program assigned to core {1}. The tensor at this output "
                      "index was not assigned or sharded to this core.",
                      cluster_func_output_index, core_id));
    return mlir::failure();
  }
  return mlir::success();
}

// Collects the tile-sharded outputs of a tf_device.parallel_execute op so they
// can be remapped to the TPU computation result.
mlir::LogicalResult GetTileShardedOutputsToMerge(
    const mlir::Location& location, const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    int cluster_idx, mlir::tf_device::ParallelExecuteOp new_parallel_execute,
    llvm::SmallVector<mlir::Value, 4>* outputs_to_merge) {
  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  outputs_to_merge->reserve(sharding.tile_assignment_devices_size());
  for (const auto logical_device_id : sharding.tile_assignment_devices()) {
    int region_output_index;
    auto status = LookupClusterToCoreIndex(
        location, cluster_to_core_index, logical_device_id,
        cluster_func_output_index, &region_output_index);
    if (failed(status)) return mlir::failure();
    const auto output_from_logical_device =
        new_parallel_execute.GetRegionOutputs(
            cluster_idx + logical_device_id)[region_output_index];
    outputs_to_merge->emplace_back(output_from_logical_device);
  }

  return mlir::success();
}

// Merges outputs from TPU computation for tile-sharded outputs.
mlir::LogicalResult HandleTileShardedOutputs(
    const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    const mlir::Location& location, mlir::Value cluster_func_output,
    int cluster_idx, mlir::tf_device::ParallelExecuteOp new_parallel_execute,
    mlir::OpBuilder* builder) {
  // Inject concat ops after parallel_execute to merge outputs from
  // concurrently executed computations.
  builder->setInsertionPointAfter(new_parallel_execute);

  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  llvm::SmallVector<mlir::Value, 4> outputs_to_merge;
  auto status = GetTileShardedOutputsToMerge(
      location, cluster_func_output_index, output_sharding_config,
      cluster_to_core_index, cluster_idx, new_parallel_execute,
      &outputs_to_merge);
  if (failed(status)) return mlir::failure();

  // Creates a tree of Concat ops that merges outputs from multiple logical
  // devices to a single replica output.
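  // For example (illustrative, not from the original source): for a [2, 2]
  // tiling, the four tile outputs (in row-major order) are first concatenated
  // pairwise along dimension 1, and the two intermediate results are then
  // concatenated along dimension 0 to reconstruct the full replica output.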
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  int concat_dimension = sharding.tile_assignment_dimensions_size() - 1;
  for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) {
    if (num_splits == 1) {
      --concat_dimension;
      continue;
    }

    llvm::SmallVector<mlir::Value, 4> new_outputs;
    new_outputs.reserve(num_splits);
    for (int i = 0, end = outputs_to_merge.size(); i < end;
         i = i + num_splits) {
      mlir::TF::ConcatOp concat_op =
          CreateConcatOp(concat_dimension, location,
                         llvm::ArrayRef<mlir::Value>{
                             outputs_to_merge.begin() + i,
                             outputs_to_merge.begin() + i + num_splits},
                         builder);
      new_outputs.emplace_back(concat_op.getResult());
    }

    std::swap(new_outputs, outputs_to_merge);
    --concat_dimension;
  }

  assert(outputs_to_merge.size() == 1);
  cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]);
  return mlir::success();
}

mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape(
    const mlir::Location& location,
    const mlir::TensorType cluster_func_output_type,
    const xla::OpSharding& output_sharding,
    mlir::Type* tiled_logical_computation_type) {
  auto new_output_shape =
      llvm::to_vector<4>(cluster_func_output_type.getShape());
  for (auto dimension_and_output_splits :
       llvm::enumerate(output_sharding.tile_assignment_dimensions())) {
    const auto dimension_index = dimension_and_output_splits.index();
    const auto output_splits = dimension_and_output_splits.value();
    const auto output_shape = cluster_func_output_type.getShape();

    if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) {
      *tiled_logical_computation_type = cluster_func_output_type;
      break;
    }

    auto output_shape_at_dim =
        cluster_func_output_type.getShape()[dimension_index];
    if (output_shape_at_dim % output_splits != 0) {
      mlir::emitError(
          location,
          llvm::formatv("incorrect output sharding received. "
                        "{0}-th dimension of the output must be "
                        "evenly divisible by {1}, got dimension "
                        "shape {2}",
                        dimension_index, output_splits, output_shape_at_dim));
      return mlir::failure();
    }

    new_output_shape[dimension_index] =
        output_shape[dimension_index] / output_splits;
  }

  *tiled_logical_computation_type = mlir::RankedTensorType::get(
      new_output_shape, cluster_func_output_type.getElementType());

  return mlir::success();
}

}  // namespace

mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    llvm::SmallVectorImpl<mlir::Type>* output_types,
    llvm::SmallVectorImpl<int>* cluster_to_core_index) {
  output_types->reserve(cluster_func.getNumResults());

  int core_index = 0;
  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();
    const auto cluster_func_output_type =
        result_and_index.value().getType().cast<mlir::TensorType>();

    // If the output shape of the cluster func is statically known and the
    // output is tile sharded, then the corresponding output dimension of the
    // cluster func must be evenly divisible by the number of shardings.
    if (output_sharding_type == xla::OpSharding::OTHER) {
      mlir::Type tiled_logical_computation_type;
      if (cluster_func_output_type.hasRank()) {
        auto result = ValidateAndGetTiledExecuteOutputShape(
            cluster_func.getLoc(), cluster_func_output_type, output_sharding,
            &tiled_logical_computation_type);
        if (mlir::failed(result)) return mlir::failure();
      } else {
        tiled_logical_computation_type = cluster_func_output_type;
      }
      cluster_to_core_index->emplace_back(core_index++);
      output_types->emplace_back(tiled_logical_computation_type);
    } else if (output_sharding_type == xla::OpSharding::REPLICATED ||
               IsAssignedToLogicalDevice(core_id, output_sharding)) {
      cluster_to_core_index->emplace_back(core_index++);
      output_types->emplace_back(cluster_func_output_type);
    } else {
      cluster_to_core_index->emplace_back(-1);
    }
  }

  return mlir::success();
}

mlir::LogicalResult RemapOutputsFromLogicalDevices(
    const mlir::Location& location,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx,
    mlir::tf_device::ParallelExecuteOp new_parallel_execute,
    mlir::OpBuilder* builder) {
  for (auto& result_and_index :
       llvm::enumerate(old_parallel_execute.getResults())) {
    const auto output_index = result_and_index.index();
    const auto old_parallel_execute_output = result_and_index.value();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();

    // If the output is demultiplexed using the `tf.TPUPartitionedOutput` op,
    // only replicated sharding is supported, where the i-th output of the
    // `tf.TPUPartitionedOutput` op maps to the output of the i-th logical
    // device. Also, the `tf.TPUPartitionedOutput` op must be the unique user
    // of the TPU cluster (`tf_device.old_parallel_execute`) output.
    mlir::TF::TPUPartitionedOutputOp partitioned_output;
    for (auto user : old_parallel_execute_output.getUsers()) {
      if (auto partitioned_output_user =
              llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedOutputOp>(user)) {
        partitioned_output = partitioned_output_user;
        break;
      }
    }
    if (partitioned_output) {
      if (!old_parallel_execute_output.hasOneUse())
        return partitioned_output.emitOpError()
               << "must be a unique user of TPU Cluster "
                  "(tf_device.old_parallel_execute) output "
               << *old_parallel_execute_output.getOwner();
      if (UnsupportedPartitionedShardingType(output_sharding_type))
        return old_parallel_execute.emitOpError()
               << "unsupported output sharding type "
               << OpSharding_Type_Name(output_sharding_type) << " for "
               << output_index << "-th output";

      if (output_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto index_and_output :
             llvm::enumerate(partitioned_output.output())) {
          const auto output_from_logical_device =
              new_parallel_execute.GetRegionOutputs(
                  cluster_idx + index_and_output.index())[output_index];
          index_and_output.value().replaceAllUsesWith(
              output_from_logical_device);
        }
      } else {
        assert(output_sharding_type == xla::OpSharding::OTHER);
        llvm::SmallVector<mlir::Value, 4> tile_sharded_outputs;
        if (failed(GetTileShardedOutputsToMerge(
                location, output_index, output_sharding_config,
                cluster_to_core_index, cluster_idx, new_parallel_execute,
                &tile_sharded_outputs)))
          return mlir::failure();
        for (auto result :
             llvm::zip(partitioned_output.output(), tile_sharded_outputs))
          std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }
      continue;
    }

    if (output_sharding_type == xla::OpSharding::OTHER) {
      if (failed(HandleTileShardedOutputs(
              output_index, output_sharding_config, cluster_to_core_index,
              location, old_parallel_execute_output, cluster_idx,
              new_parallel_execute, builder)))
        return mlir::failure();
      continue;
    }

    int logical_device_id = 0;
    if (output_sharding_type == xla::OpSharding::MAXIMAL)
      logical_device_id = output_sharding.tile_assignment_devices(0);

    // For maximal sharding configuration, correctly remap outputs from
    // parallel_execute region to users of the cluster func.
    int region_output_index;
    if (failed(LookupClusterToCoreIndex(location, cluster_to_core_index,
                                        logical_device_id, output_index,
                                        &region_output_index)))
      return mlir::failure();

    const auto output_from_logical_device =
        new_parallel_execute.GetRegionOutputs(
            cluster_idx + logical_device_id)[region_output_index];
    old_parallel_execute_output.replaceAllUsesWith(output_from_logical_device);
  }
  return mlir::success();
}

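// Returns, for each logical core, the list of argument indices (into
// `metadata.args()`) that are fed to that core: every argument when there is a
// single core per replica; otherwise tiled arguments map to each device in
// their tile assignment, replicated arguments map to every core, and maximal
// arguments map to their single assigned device.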
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> GetMetadataArgumentMapping(
    const tpu::TPUCompileMetadataProto& metadata) {
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings(
      metadata.num_cores_per_replica(), llvm::SmallVector<int64_t, 4>());

  if (metadata.num_cores_per_replica() == 1) {
    input_mappings.front().resize(metadata.args_size());
    std::iota(input_mappings.front().begin(), input_mappings.front().end(), 0);
    return input_mappings;
  }

  for (const auto& arg_and_idx : llvm::enumerate(metadata.args())) {
    const auto& sharding = arg_and_idx.value().sharding();
    const int64_t idx = arg_and_idx.index();

    const auto sharding_type = sharding.type();
    if (sharding_type == xla::OpSharding::OTHER) {
      for (const auto& device : sharding.tile_assignment_devices())
        input_mappings[device].push_back(idx);
    } else if (sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& input : input_mappings) input.push_back(idx);
    } else {
      assert(sharding_type == xla::OpSharding::MAXIMAL);
      input_mappings[sharding.tile_assignment_devices(0)].push_back(idx);
    }
  }

  return input_mappings;
}

}  // namespace tensorflow