/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"

#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace tensorflow {

const char* const kInputShardingAttr = "input_sharding_configuration";
const char* const kOutputShardingAttr = "output_sharding_configuration";

namespace {

constexpr char kNumSplitAttr[] = "num_split";

// Creates a tf::SplitOp that splits `src_input` into `num_split` slices along
// `split_dimension`, returning the created op through `split_op`.
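// For example, splitting a ranked tensor<8x16xf32> value 4 ways along
// dimension 1 produces a tf.Split op with four tensor<8x4xf32> results.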
mlir::LogicalResult CreateSplitOp(const int num_split,
                                  const int split_dimension,
                                  const mlir::Location& location,
                                  mlir::Value src_input,
                                  mlir::OpBuilder* builder,
                                  mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly set the output shape of the split op results if the input shape
  // is statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return mlir::emitError(
            location,
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split));
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  (*split_op)->setAttr(
      kNumSplitAttr,
      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
  return mlir::success();
}

// Creates a tf::ConcatOp that merges the `inputs` values along
// `concat_dimension`.
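// For example, concatenating four tensor<8x4xf32> values along dimension 1
// produces a single tensor<8x16xf32> result.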
mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension,
                                  const mlir::Location& location,
                                  mlir::ArrayRef<mlir::Value> inputs,
                                  mlir::OpBuilder* builder) {
  // Creates a const op to hold concat dimension value.
  auto concat_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto concat_dimension_attr =
      mlir::DenseElementsAttr::get(concat_dim_type, concat_dimension);
  auto concat_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, concat_dim_type, concat_dimension_attr);

  // Correctly set the output shape of the concat op if the output shape is
  // statically known. Since the output shape of the TPUExecute op must be the
  // same across logical devices, we refer to the shape of the 0th logical
  // device's computation output.
  mlir::Type output_type;
  auto input_type = inputs[0].getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[concat_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      shape[concat_dimension] = shape[concat_dimension] * inputs.size();
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  return builder->create<mlir::TF::ConcatOp>(
      location, output_type, concat_dimension_op.output(), inputs);
}

// For tile-sharded inputs to a TPU computation, injects split ops between the
// input values and the TPU computation so that tiled input values are passed
// in as inputs to the TPU computation. If more than one dimension is sharded,
// a tree of connected split ops is added before the tf_device.parallel_execute
// op.
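// For example, with tile_assignment_dimensions [2, 2], the input is first
// split in half along dimension 0, and each half is then split along
// dimension 1, producing four tiles that are fed to logical devices 0..3 in
// row-major order.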
mlir::LogicalResult HandleTileShardedInputs(
    const mlir::Location& location, const xla::OpSharding& input_sharding,
    const mlir::Value& original_source, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
  split_ops_for_tiled_input.reserve(
      input_sharding.tile_assignment_devices_size());

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that input data is sharded in row-major order.
  // Split nodes at the i-th depth from the original input node split the
  // input data along the i-th dimension.
  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
    const int num_splits = num_splits_and_index.value();
    const int dimension_index = num_splits_and_index.index();
    if (num_splits == 1) continue;

    // Creates root split op.
    if (split_ops_for_tiled_input.empty()) {
      mlir::TF::SplitOp root_split_op;
      auto result = CreateSplitOp(num_splits, dimension_index, location,
                                  original_source, builder, &root_split_op);
      if (mlir::failed(result)) return mlir::failure();

      split_ops_for_tiled_input.emplace_back(root_split_op);
      continue;
    }

    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);

    for (auto split_op : split_ops_for_tiled_input) {
      for (auto parent_split_output_value : split_op.getResults()) {
        mlir::TF::SplitOp child_split_op;
        auto result =
            CreateSplitOp(num_splits, dimension_index, location,
                          parent_split_output_value, builder, &child_split_op);
        if (mlir::failed(result)) return mlir::failure();

        new_split_ops.emplace_back(child_split_op);
      }
    }

    std::swap(new_split_ops, split_ops_for_tiled_input);
  }

  // `split_ops_for_tiled_input` now contains the final split nodes
  // from which sharded data will be fed into TPUExecute ops -- sorted in
  // row-major order.
  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
  for (auto split_op : split_ops_for_tiled_input)
    tiled_inputs->append(split_op.getResults().begin(),
                         split_op.getResults().end());

  return mlir::success();
}

bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) {
  return sharding != xla::OpSharding::REPLICATED &&
         sharding != xla::OpSharding::OTHER;
}

}  // namespace

mlir::LogicalResult ExtractInputsForLogicalDevices(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
  // Initialize the input list for each logical device.
  input_list->reserve(num_cores_per_replica);
  for (int i = 0; i < num_cores_per_replica; ++i)
    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());

  llvm::SmallVector<mlir::Value, 4> cluster_func_inputs(
      cluster_func.getOperands());
  auto sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kInputShardingAttr);
  // If the sharding attribute does not exist, then all inputs are placed on
  // the 0th logical core by default.
  if (!sharding_attrs) {
    (*input_list)[0] = cluster_func_inputs;
    return mlir::success();
  }

  // Enumerates the sharding configuration of each input. If an input has
  // replicated sharding, then all logical devices take the value as input. If
  // an input has maximal sharding, then only the specified logical device
  // takes the value as input.
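  // For example, with num_cores_per_replica = 2 and input shardings
  // [REPLICATED, MAXIMAL on device 1], input 0 is appended to both entries of
  // `input_list` while input 1 is appended only to (*input_list)[1].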
  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
    const auto& sharding_attr = sharding_attr_and_index.value();
    const auto input_index = sharding_attr_and_index.index();
    const auto& input_value = cluster_func_inputs[input_index];

    xla::OpSharding sharding;
    sharding.ParseFromString(
        sharding_attr.cast<mlir::StringAttr>().getValue().str());

    const auto input_sharding_type = sharding.type();

    auto tiled_sharding_mismatched = [&](int tiled_input_size) {
      return cluster_func.emitError(
          llvm::formatv("incorrect {0}-th tiled input sharding received. "
                        "Product of tile sharding splits({1}) must be equal to "
                        "number of logical devices : {2}",
                        input_index, tiled_input_size, num_cores_per_replica));
    };

    // If the input is already partitioned using the `tf.TPUPartitionedInput`
    // op, only replicated or tile (OTHER) sharding is supported. With
    // replicated sharding, the i-th operand of the `tf.TPUPartitionedInput` op
    // is the input to the i-th logical device.
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedInputOp>(
                input_value.getDefiningOp())) {
      if (UnsupportedPartitionedShardingType(input_sharding_type))
        return cluster_func->emitOpError()
               << "unsupported input sharding type "
               << OpSharding_Type_Name(input_sharding_type) << " for "
               << input_index << "-th input";

      if (input_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto& index_and_inputs : llvm::enumerate(*input_list)) {
          index_and_inputs.value().emplace_back(
              partitioned_input.getOperand(index_and_inputs.index()));
        }
      } else {
        assert(input_sharding_type == xla::OpSharding::OTHER);
        if (partitioned_input.inputs().size() != num_cores_per_replica)
          return tiled_sharding_mismatched(partitioned_input.inputs().size());

        for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
          const int assigned_logical_device =
              sharding.tile_assignment_devices(i);
          (*input_list)[assigned_logical_device].emplace_back(
              partitioned_input.inputs()[i]);
        }
      }
      continue;
    }

    if (input_sharding_type == xla::OpSharding::OTHER) {
      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
      auto result = HandleTileShardedInputs(
          cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
      if (mlir::failed(result)) return mlir::failure();

      const int64_t tiled_inputs_size = tiled_inputs.size();
      if (tiled_inputs_size != num_cores_per_replica)
        return tiled_sharding_mismatched(tiled_inputs.size());

      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
        const int assigned_logical_device = sharding.tile_assignment_devices(i);
        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
      }
    } else if (input_sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
    } else {
      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
      const int logical_device_id = sharding.tile_assignment_devices(0);
      (*input_list)[logical_device_id].emplace_back(input_value);
    }
  }
  return mlir::success();
}

mlir::LogicalResult ParseAndValidateOutputSharding(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list) {
  output_sharding_list->reserve(cluster_func.getNumResults());

  const auto output_sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kOutputShardingAttr);
  if (!output_sharding_attrs)
    return cluster_func.emitError(
        "output_sharding_configuration missing from cluster func");

  if (output_sharding_attrs.size() != cluster_func.getNumResults())
    return cluster_func.emitError("incorrect number of output sharding");

  for (auto output_sharding_and_index :
       llvm::enumerate(output_sharding_attrs)) {
    const auto& output_sharding = output_sharding_and_index.value();
    const int sharding_index = output_sharding_and_index.index();
    if (!output_sharding.isa<mlir::StringAttr>())
      return cluster_func.emitError(llvm::formatv(
          "non-string output sharding at index {0}", sharding_index));

    xla::OpSharding sharding;
    if (!sharding.ParseFromString(
            output_sharding.cast<mlir::StringAttr>().getValue().str()))
      return cluster_func.emitError("incorrect sharding format for outputs");

    if (sharding.type() == xla::OpSharding::OTHER &&
        sharding.tile_assignment_devices_size() != num_cores_per_replica)
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Number of "
          "tiled outputs({0}) must match the number of logical "
          "devices({1})",
          sharding.tile_assignment_devices_size(), num_cores_per_replica));

    if (sharding.type() == xla::OpSharding::MAXIMAL &&
        ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) ||
         (sharding.tile_assignment_devices(0) < 0)))
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Maximal "
          "sharding should be assigned to device id in range "
          "[0, {0}). Currently assigned to {1}",
          num_cores_per_replica, sharding.tile_assignment_devices(0)));

    output_sharding_list->emplace_back(std::move(sharding));
  }
  return mlir::success();
}

namespace {

bool IsAssignedToLogicalDevice(const int core_id,
                               const xla::OpSharding& sharding) {
  return sharding.type() == xla::OpSharding::MAXIMAL &&
         sharding.tile_assignment_devices(0) == core_id;
}

// Returns the index of the return value in the `tf_device.parallel_execute`
// region of logical core |core_id| that corresponds to the cluster func output
// at index |cluster_func_output_index|. Regions of parallel_execute may have
// different numbers of return values depending on the output sharding
// configuration.
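// For example, with output shardings [MAXIMAL on device 0, REPLICATED, OTHER]
// and core_id = 1, the cluster func output at index 2 maps to region output
// index 1, because the maximal output at index 0 is only present in the region
// of device 0.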
int MapClusterOutputIndexWithRegionOutputIndex(
    llvm::ArrayRef<xla::OpSharding> output_sharding_config, const int core_id,
    const int cluster_func_output_index) {
  int region_output_index = 0;
  for (int output_index = 0; output_index < cluster_func_output_index;
       ++output_index) {
    const auto& sharding = output_sharding_config[output_index];
    if (sharding.type() != xla::OpSharding::MAXIMAL ||
        IsAssignedToLogicalDevice(core_id, sharding))
      region_output_index++;
  }

  return region_output_index;
}

// Collects the tile-sharded outputs of a tf_device.parallel_execute that
// together form the given TPU computation (cluster func) result.
llvm::SmallVector<mlir::Value, 4> GetTileShardedOutputsToMerge(
    const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ParallelExecuteOp parallel_execute) {
  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  llvm::SmallVector<mlir::Value, 4> outputs_to_merge;
  outputs_to_merge.reserve(sharding.tile_assignment_devices_size());
  for (const auto logical_device_id : sharding.tile_assignment_devices()) {
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        output_sharding_config, logical_device_id, cluster_func_output_index);
    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    outputs_to_merge.emplace_back(output_from_logical_device);
  }

  return outputs_to_merge;
}

// Merges outputs from TPU computation for tile-sharded outputs.
void HandleTileShardedOutputs(
    const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    const mlir::Location& location, mlir::Value cluster_func_output,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  // Inject concat ops after parallel_execute to merge outputs from
  // concurrently executed computations.
  builder->setInsertionPointAfter(parallel_execute);

  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  auto outputs_to_merge = GetTileShardedOutputsToMerge(
      cluster_func_output_index, output_sharding_config, parallel_execute);

  // Creates a tree of Concat ops that merges outputs from multiple logical
  // devices to a single replica output.
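  // For example, with tile_assignment_dimensions [2, 2], the four per-device
  // outputs are first concatenated pairwise along dimension 1, and the two
  // intermediate results are then concatenated along dimension 0.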
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  int concat_dimension = sharding.tile_assignment_dimensions_size() - 1;
  for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) {
    if (num_splits == 1) {
      --concat_dimension;
      continue;
    }

    llvm::SmallVector<mlir::Value, 4> new_outputs;
    new_outputs.reserve(num_splits);
    for (int i = 0, end = outputs_to_merge.size(); i < end;
         i = i + num_splits) {
      mlir::TF::ConcatOp concat_op =
          CreateConcatOp(concat_dimension, location,
                         llvm::ArrayRef<mlir::Value>{
                             outputs_to_merge.begin() + i,
                             outputs_to_merge.begin() + i + num_splits},
                         builder);
      new_outputs.emplace_back(concat_op.getResult());
    }

    std::swap(new_outputs, outputs_to_merge);
    --concat_dimension;
  }

  assert(outputs_to_merge.size() == 1);
  cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]);
}

mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape(
    const mlir::Location& location,
    const mlir::TensorType cluster_func_output_type,
    const xla::OpSharding& output_sharding,
    mlir::Type* tiled_logical_computation_type) {
  auto new_output_shape =
      llvm::to_vector<4>(cluster_func_output_type.getShape());
  for (auto dimension_and_output_splits :
       llvm::enumerate(output_sharding.tile_assignment_dimensions())) {
    const auto dimension_index = dimension_and_output_splits.index();
    const auto output_splits = dimension_and_output_splits.value();
    const auto output_shape = cluster_func_output_type.getShape();

    if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) {
      *tiled_logical_computation_type = cluster_func_output_type;
      break;
    }

    auto output_shape_at_dim =
        cluster_func_output_type.getShape()[dimension_index];
    if (output_shape_at_dim % output_splits != 0) {
      return mlir::emitError(
          location,
          llvm::formatv("incorrect output sharding received. "
                        "{0}-th dimension of the output must be "
                        "evenly divisible by {1}, got dimension "
                        "shape {2}",
                        dimension_index, output_splits, output_shape_at_dim));
    }

    new_output_shape[dimension_index] =
        output_shape[dimension_index] / output_splits;
  }

  *tiled_logical_computation_type = mlir::RankedTensorType::get(
      new_output_shape, cluster_func_output_type.getElementType());

  return mlir::success();
}

}  // namespace

mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    llvm::SmallVectorImpl<mlir::Type>* output_types) {
  output_types->reserve(cluster_func.getNumResults());

  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();
    const auto cluster_func_output_type =
        result_and_index.value().getType().cast<mlir::TensorType>();

    // If the output shape of the cluster func is statically known and the
    // output is tile-sharded, then the corresponding output dimensions of the
    // cluster func must be evenly divisible by the number of shardings along
    // those dimensions.
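    // For example, a tensor<8x16xf32> output with tile_assignment_dimensions
    // [2, 4] yields a per-logical-device output type of tensor<4x4xf32>.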
    if (output_sharding_type == xla::OpSharding::OTHER) {
      mlir::Type tiled_logical_computation_type;
      if (cluster_func_output_type.hasRank()) {
        auto result = ValidateAndGetTiledExecuteOutputShape(
            cluster_func.getLoc(), cluster_func_output_type, output_sharding,
            &tiled_logical_computation_type);
        if (mlir::failed(result)) return mlir::failure();
      } else {
        tiled_logical_computation_type = cluster_func_output_type;
      }
      output_types->emplace_back(tiled_logical_computation_type);
    } else if (output_sharding_type == xla::OpSharding::REPLICATED ||
               IsAssignedToLogicalDevice(core_id, output_sharding)) {
      output_types->emplace_back(cluster_func_output_type);
    }
  }

  return mlir::success();
}

mlir::LogicalResult RemapOutputsFromLogicalDevices(
    const mlir::Location& location,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto cluster_func_output = result_and_index.value();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();

    // If the output is demultiplexed using the `tf.TPUPartitionedOutput` op,
    // only replicated or tile (OTHER) sharding is supported. With replicated
    // sharding, the i-th output of the `tf.TPUPartitionedOutput` op maps to
    // the output of the i-th logical device. The `tf.TPUPartitionedOutput` op
    // must also be the unique user of the `tf_device.cluster_func` output.
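    // For example, with REPLICATED sharding and two logical devices, result i
    // of the `tf.TPUPartitionedOutput` op is replaced with output
    // `output_index` of region i of the parallel_execute op.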
    mlir::TF::TPUPartitionedOutputOp partitioned_output;
    for (auto user : cluster_func_output.getUsers()) {
      if (auto partitioned_output_user =
              llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedOutputOp>(user)) {
        partitioned_output = partitioned_output_user;
        break;
      }
    }
    if (partitioned_output) {
      if (!cluster_func_output.hasOneUse())
        return partitioned_output.emitOpError()
               << "must be a unique user of tf_device.cluster_func output "
               << *cluster_func_output.getOwner();
      if (UnsupportedPartitionedShardingType(output_sharding_type))
        return cluster_func.emitOpError()
               << "unsupported output sharding type "
               << OpSharding_Type_Name(output_sharding_type) << " for "
               << output_index << "-th output";

      if (output_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto index_and_output :
             llvm::enumerate(partitioned_output.output())) {
          const auto output_from_logical_device =
              parallel_execute.GetRegionOutputs(
                  index_and_output.index())[output_index];
          index_and_output.value().replaceAllUsesWith(
              output_from_logical_device);
        }
      } else {
        assert(output_sharding_type == xla::OpSharding::OTHER);
        auto tile_sharded_outputs = GetTileShardedOutputsToMerge(
            output_index, output_sharding_config, parallel_execute);
        for (auto result :
             llvm::zip(partitioned_output.output(), tile_sharded_outputs))
          std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }
      continue;
    }

    if (output_sharding_type == xla::OpSharding::OTHER) {
      HandleTileShardedOutputs(output_index, output_sharding_config, location,
                               cluster_func_output, parallel_execute, builder);
      continue;
    }

    int logical_device_id = 0;
    if (output_sharding_type == xla::OpSharding::MAXIMAL)
      logical_device_id = output_sharding.tile_assignment_devices(0);

    // For maximal sharding configuration, correctly remap outputs from
    // parallel_execute region to users of the cluster func.
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        output_sharding_config, logical_device_id, output_index);

    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    cluster_func_output.replaceAllUsesWith(output_from_logical_device);
  }
  return mlir::success();
}

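// Maps each logical core to the metadata argument indices it consumes, based
// on each argument's sharding. For example, with num_cores_per_replica = 2 and
// argument shardings [REPLICATED, MAXIMAL on device 0, OTHER with
// tile_assignment_devices {0, 1}], the returned mapping is {{0, 1, 2}, {0, 2}}.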
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> GetMetadataArgumentMapping(
    const tpu::TPUCompileMetadataProto& metadata) {
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings(
      metadata.num_cores_per_replica(), llvm::SmallVector<int64_t, 4>());

  if (metadata.num_cores_per_replica() == 1) {
    input_mappings.front().resize(metadata.args_size());
    std::iota(input_mappings.front().begin(), input_mappings.front().end(), 0);
    return input_mappings;
  }

  for (const auto& arg_and_idx : llvm::enumerate(metadata.args())) {
    const auto& sharding = arg_and_idx.value().sharding();
    const int64_t idx = arg_and_idx.index();

    const auto sharding_type = sharding.type();
    if (sharding_type == xla::OpSharding::OTHER) {
      for (const auto& device : sharding.tile_assignment_devices())
        input_mappings[device].push_back(idx);
    } else if (sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& input : input_mappings) input.push_back(idx);
    } else {
      assert(sharding_type == xla::OpSharding::MAXIMAL);
      input_mappings[sharding.tile_assignment_devices(0)].push_back(idx);
    }
  }

  return input_mappings;
}

}  // namespace tensorflow