/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"

#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace tensorflow {

const char* const kInputShardingAttr = "input_sharding_configuration";
const char* const kOutputShardingAttr = "output_sharding_configuration";

namespace {

constexpr char kNumSplitAttr[] = "num_split";

// Creates a tf::SplitOp that splits 'src_input' 'num_split' ways along the
// 'split_dimension' dimension and returns the split values.
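// E.g., splitting a tensor<4x8xf32> two ways along dimension 1 produces two
// tensor<4x4xf32> values, matching the output types computed below.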
mlir::LogicalResult CreateSplitOp(const int num_split,
                                  const int split_dimension,
                                  const mlir::Location& location,
                                  mlir::Value src_input,
                                  mlir::OpBuilder* builder,
                                  mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold the split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly set the output shapes of the split op if the input shape is
  // statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return mlir::emitError(
            location,
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split));
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  (*split_op)->setAttr(
      kNumSplitAttr,
      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
  return mlir::success();
}

// Creates a tf::ConcatOp that merges `inputs` along the `concat_dimension`
// dimension.
mlir::LogicalResult CreateConcatOp(const int concat_dimension,
                                   const mlir::Location& location,
                                   mlir::ArrayRef<mlir::Value> inputs,
                                   mlir::OpBuilder* builder,
                                   mlir::TF::ConcatOp* concat_op) {
  // Creates a const op to hold the concat dimension value.
  auto concat_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto concat_dimension_attr =
      mlir::DenseElementsAttr::get(concat_dim_type, concat_dimension);
  auto concat_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, concat_dim_type, concat_dimension_attr);

  // Correctly set the output shape of the concat op if the input shape is
  // statically known. Since the output shape of the TPUExecute op must be the
  // same across logical devices, we refer to the shape of the 0th logical
  // device's computation output.
  mlir::Type output_type;
  auto input_type = inputs[0].getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[concat_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      shape[concat_dimension] = shape[concat_dimension] * inputs.size();
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  *concat_op = builder->create<mlir::TF::ConcatOp>(
      location, output_type, concat_dimension_op.output(), inputs);
  return mlir::success();
}

// For an input that is tile sharded across the TPU computation, injects split
// ops between the input value and the TPU computation so that the tiled input
// values are passed in as inputs to the per-device computations. If more than
// one dimension is sharded, a tree of connected split ops is added before the
// tf_device.parallel_execute op.
mlir::LogicalResult HandleTileShardedInputs(
    const mlir::Location& location, const xla::OpSharding& input_sharding,
    const mlir::Value& original_source, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
  split_ops_for_tiled_input.reserve(
      input_sharding.tile_assignment_devices_size());

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that the input data is sharded in row-major order.
  // Split nodes at the i-th depth from the original input node split the
  // input data along the i-th dimension.
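  // E.g., with tile_assignment_dimensions [2, 2] the original input is first
  // split two ways along dimension 0; each of those results is then split two
  // ways along dimension 1, yielding four tiles ordered
  // [0,0], [0,1], [1,0], [1,1].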
  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
    const int num_splits = num_splits_and_index.value();
    const int dimension_index = num_splits_and_index.index();
    if (num_splits == 1) continue;

    // Creates the root split op.
    if (split_ops_for_tiled_input.empty()) {
      mlir::TF::SplitOp root_split_op;
      auto result = CreateSplitOp(num_splits, dimension_index, location,
                                  original_source, builder, &root_split_op);
      if (mlir::failed(result)) return mlir::failure();

      split_ops_for_tiled_input.emplace_back(root_split_op);
      continue;
    }

    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);

    for (auto split_op : split_ops_for_tiled_input) {
      for (auto parent_split_output_value : split_op.getResults()) {
        mlir::TF::SplitOp child_split_op;
        auto result =
            CreateSplitOp(num_splits, dimension_index, location,
                          parent_split_output_value, builder, &child_split_op);
        if (mlir::failed(result)) return mlir::failure();

        new_split_ops.emplace_back(child_split_op);
      }
    }

    std::swap(new_split_ops, split_ops_for_tiled_input);
  }

  // `split_ops_for_tiled_input` now includes the final split nodes from which
  // the sharded data will be fed into the TPUExecute ops -- sorted in
  // row-major order.
  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
  for (auto split_op : split_ops_for_tiled_input)
    tiled_inputs->append(split_op.getResults().begin(),
                         split_op.getResults().end());

  return mlir::success();
}

bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) {
  return sharding != xla::OpSharding::REPLICATED &&
         sharding != xla::OpSharding::OTHER;
}

}  // namespace

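// Extracts the input values of `cluster_func` into one input list per logical
// core, based on the `input_sharding_configuration` attribute attached to the
// op.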
mlir::LogicalResult ExtractInputsForLogicalDevices(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
  // Initialize the input list for each logical device.
  input_list->reserve(num_cores_per_replica);
  for (int i = 0; i < num_cores_per_replica; ++i)
    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());

  llvm::SmallVector<mlir::Value, 4> cluster_func_inputs(
      cluster_func.getOperands());
  auto sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kInputShardingAttr);
  // If the sharding attribute does not exist, then all inputs are placed on
  // the 0th logical core by default.
  if (!sharding_attrs) {
    (*input_list)[0] = cluster_func_inputs;
    return mlir::success();
  }

  // Enumerate the sharding configuration of each input. If an input has
  // replicated sharding, then all logical devices take the value as input. If
  // an input has maximal sharding, then only the specified logical device
  // takes the value as its input.
  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
    const auto& sharding_attr = sharding_attr_and_index.value();
    const auto input_index = sharding_attr_and_index.index();
    const auto& input_value = cluster_func_inputs[input_index];

    xla::OpSharding sharding;
    sharding.ParseFromString(
        sharding_attr.cast<mlir::StringAttr>().getValue().str());

    const auto input_sharding_type = sharding.type();

    auto tiled_sharding_mismatched = [&](int tiled_input_size) {
      return cluster_func.emitError(
          llvm::formatv("incorrect {0}-th tiled input sharding received. "
                        "Product of tile sharding splits({1}) must be equal to "
                        "number of logical devices : {2}",
                        input_index, tiled_input_size, num_cores_per_replica));
    };

    // If the input is already partitioned using the `tf.TPUPartitionedInput`
    // op, only replicated and tiled sharding are supported. For replicated
    // sharding, the i-th operand of the `tf.TPUPartitionedInput` op is the
    // input to the i-th logical device.
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedInputOp>(
                input_value.getDefiningOp())) {
      if (UnsupportedPartitionedShardingType(input_sharding_type))
        return cluster_func->emitOpError()
               << "unsupported input sharding type "
               << OpSharding_Type_Name(input_sharding_type) << " for "
               << input_index << "-th input";

      if (input_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto& index_and_inputs : llvm::enumerate(*input_list)) {
          index_and_inputs.value().emplace_back(
              partitioned_input.getOperand(index_and_inputs.index()));
        }
      } else {
        assert(input_sharding_type == xla::OpSharding::OTHER);
        if (partitioned_input.inputs().size() != num_cores_per_replica)
          return tiled_sharding_mismatched(partitioned_input.inputs().size());

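        // For a tile-sharded `tf.TPUPartitionedInput`, the i-th operand
        // corresponds to the i-th entry of the tile assignment, so route it to
        // the logical device assigned to that tile.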
        for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
          const int assigned_logical_device =
              sharding.tile_assignment_devices(i);
          (*input_list)[assigned_logical_device].emplace_back(
              partitioned_input.inputs()[i]);
        }
      }
      continue;
    }

    if (input_sharding_type == xla::OpSharding::OTHER) {
      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
      auto result = HandleTileShardedInputs(
          cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
      if (mlir::failed(result)) return mlir::failure();

      const int64 tiled_inputs_size = tiled_inputs.size();
      if (tiled_inputs_size != num_cores_per_replica)
        return tiled_sharding_mismatched(tiled_inputs.size());

      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
        const int assigned_logical_device = sharding.tile_assignment_devices(i);
        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
      }
    } else if (input_sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
    } else {
      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
      const int logical_device_id = sharding.tile_assignment_devices(0);
      (*input_list)[logical_device_id].emplace_back(input_value);
    }
  }
  return mlir::success();
}

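// Parses the `output_sharding_configuration` attribute of `cluster_func` and
// validates each sharding entry against `num_cores_per_replica`.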
mlir::LogicalResult ParseAndValidateOutputSharding(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list) {
  output_sharding_list->reserve(cluster_func.getNumResults());

  const auto output_sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kOutputShardingAttr);
  if (!output_sharding_attrs)
    return cluster_func.emitError(
        "output_sharding_configuration missing from cluster func");

  if (output_sharding_attrs.size() != cluster_func.getNumResults())
    return cluster_func.emitError("incorrect number of output sharding");

  for (auto output_sharding_and_index :
       llvm::enumerate(output_sharding_attrs)) {
    const auto& output_sharding = output_sharding_and_index.value();
    const int sharding_index = output_sharding_and_index.index();
    if (!output_sharding.isa<mlir::StringAttr>())
      return cluster_func.emitError(llvm::formatv(
          "non-string output sharding at index {0}", sharding_index));

    xla::OpSharding sharding;
    if (!sharding.ParseFromString(
            output_sharding.cast<mlir::StringAttr>().getValue().str()))
      return cluster_func.emitError("incorrect sharding format for outputs");

    if (sharding.type() == xla::OpSharding::OTHER &&
        sharding.tile_assignment_devices_size() != num_cores_per_replica)
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Number of "
          "tiled outputs({0}) must match the number of logical "
          "devices({1})",
          sharding.tile_assignment_devices_size(), num_cores_per_replica));

    if (sharding.type() == xla::OpSharding::MAXIMAL &&
        ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) ||
         (sharding.tile_assignment_devices(0) < 0)))
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Maximal "
          "sharding should be assigned to device id in range "
          "[0, {0}). Currently assigned to {1}",
          num_cores_per_replica, sharding.tile_assignment_devices(0)));

    output_sharding_list->emplace_back(std::move(sharding));
  }
  return mlir::success();
}

namespace {

bool IsAssignedToLogicalDevice(const int core_id,
                               const xla::OpSharding& sharding) {
  return sharding.type() == xla::OpSharding::MAXIMAL &&
         sharding.tile_assignment_devices(0) == core_id;
}

// Returns the index of the return value of the `tf_device.parallel_execute`
// region that represents the cluster func output at index
// |cluster_func_output_index|. Regions of parallel_execute may have different
// return values depending on the output sharding configuration.
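// Maximal-sharded outputs appear only in the region of their assigned logical
// device, so they are skipped when counting outputs for other cores. E.g.,
// with output shardings [MAXIMAL(0), REPLICATED, MAXIMAL(1)], the region for
// core 1 returns only cluster outputs 1 and 2, so cluster output index 2 maps
// to region output index 1.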
int MapClusterOutputIndexWithRegionOutputIndex(
    llvm::ArrayRef<xla::OpSharding> output_sharding_config, const int core_id,
    const int cluster_func_output_index) {
  int region_output_index = 0;
  for (int output_index = 0; output_index < cluster_func_output_index;
       ++output_index) {
    const auto& sharding = output_sharding_config[output_index];
    if (sharding.type() != xla::OpSharding::MAXIMAL ||
        IsAssignedToLogicalDevice(core_id, sharding))
      region_output_index++;
  }

  return region_output_index;
}

// Collects the tile-sharded outputs from a tf_device.parallel_execute op that
// together form a single cluster func output, in the order given by the
// sharding's tile_assignment_devices.
llvm::SmallVector<mlir::Value, 4> GetTileShardedOutputsToMerge(
    const int cluster_func_output_index, const xla::OpSharding& sharding,
    mlir::tf_device::ParallelExecuteOp parallel_execute) {
  // Reorders outputs from the TPUExecute ops as defined by the output sharding
  // configuration.
  llvm::SmallVector<mlir::Value, 4> outputs_to_merge;
  outputs_to_merge.reserve(sharding.tile_assignment_devices_size());
  for (const auto logical_device_id : sharding.tile_assignment_devices()) {
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        sharding, logical_device_id, cluster_func_output_index);
    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    outputs_to_merge.emplace_back(output_from_logical_device);
  }

  return outputs_to_merge;
}

// Merges the per-logical-device outputs of the TPU computation back into a
// single value for a tile-sharded output.
mlir::LogicalResult HandleTileShardedOutputs(
    const int cluster_func_output_index, const xla::OpSharding& sharding,
    const mlir::Location& location, mlir::Value cluster_func_output,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  // Inject concat ops after parallel_execute to merge outputs from
  // concurrently executed computations.
  builder->setInsertionPointAfter(parallel_execute);

  // Reorders outputs from the TPUExecute ops as defined by the output sharding
  // configuration.
  auto outputs_to_merge = GetTileShardedOutputsToMerge(
      cluster_func_output_index, sharding, parallel_execute);

  // Creates a tree of Concat ops that merges outputs from multiple logical
  // devices to a single replica output.
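  // E.g., for tile_assignment_dimensions [2, 2] the four tile outputs (in
  // row-major order) are first concatenated pairwise along dimension 1, and
  // the two intermediate results are then concatenated along dimension 0.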
  int concat_dimension = sharding.tile_assignment_dimensions_size() - 1;
  for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) {
    if (num_splits == 1) {
      --concat_dimension;
      continue;
    }

    llvm::SmallVector<mlir::Value, 4> new_outputs;
    new_outputs.reserve(num_splits);
    for (int i = 0, end = outputs_to_merge.size(); i < end;
         i = i + num_splits) {
      mlir::TF::ConcatOp concat_op;
      auto result =
          CreateConcatOp(concat_dimension, location,
                         llvm::ArrayRef<mlir::Value>{
                             outputs_to_merge.begin() + i,
                             outputs_to_merge.begin() + i + num_splits},
                         builder, &concat_op);
      if (mlir::failed(result)) return mlir::failure();

      new_outputs.emplace_back(concat_op.getResult());
    }

    std::swap(new_outputs, outputs_to_merge);
    --concat_dimension;
  }

  assert(outputs_to_merge.size() == 1);
  cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]);
  return mlir::success();
}

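// Validates that each statically known dimension of the cluster func output
// is evenly divisible by the number of tile splits in that dimension, and
// computes the per-logical-device output type. E.g., a tensor<4x8xf32> output
// with tile_assignment_dimensions [1, 2] yields tensor<4x4xf32> per device.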
mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape(
    const mlir::Location& location,
    const mlir::TensorType cluster_func_output_type,
    const xla::OpSharding& output_sharding,
    mlir::Type* tiled_logical_computation_type) {
  auto new_output_shape =
      llvm::to_vector<4>(cluster_func_output_type.getShape());
  for (auto dimension_and_output_splits :
       llvm::enumerate(output_sharding.tile_assignment_dimensions())) {
    const auto dimension_index = dimension_and_output_splits.index();
    const auto output_splits = dimension_and_output_splits.value();
    const auto output_shape = cluster_func_output_type.getShape();

    if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) {
      *tiled_logical_computation_type = cluster_func_output_type;
      break;
    }

    auto output_shape_at_dim =
        cluster_func_output_type.getShape()[dimension_index];
    if (output_shape_at_dim % output_splits != 0) {
      return mlir::emitError(
          location,
          llvm::formatv("incorrect output sharding received. "
                        "{0}-th dimension of the output must be "
                        "evenly divisible by {1}, got dimension "
                        "shape {2}",
                        dimension_index, output_splits, output_shape_at_dim));
    }

    new_output_shape[dimension_index] =
        output_shape[dimension_index] / output_splits;
  }

  *tiled_logical_computation_type = mlir::RankedTensorType::get(
      new_output_shape, cluster_func_output_type.getElementType());

  return mlir::success();
}

}  // namespace

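// Computes the result types of the `tf_device.parallel_execute` region for
// logical core `core_id`: tile-sharded outputs get their per-device tiled
// type, replicated outputs appear on every core with their original type, and
// maximal outputs appear only on their assigned core.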
mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    llvm::SmallVectorImpl<mlir::Type>* output_types) {
  output_types->reserve(cluster_func.getNumResults());

  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();
    const auto cluster_func_output_type =
        result_and_index.value().getType().cast<mlir::TensorType>();

    // If the output shape of the cluster func is statically known and the
    // output is tile sharded, then the corresponding output shape of the
    // cluster func must be evenly divisible by the number of tile splits in
    // each dimension.
    if (output_sharding_type == xla::OpSharding::OTHER) {
      mlir::Type tiled_logical_computation_type;
      if (cluster_func_output_type.hasRank()) {
        auto result = ValidateAndGetTiledExecuteOutputShape(
            cluster_func.getLoc(), cluster_func_output_type, output_sharding,
            &tiled_logical_computation_type);
        if (mlir::failed(result)) return mlir::failure();
      } else {
        tiled_logical_computation_type = cluster_func_output_type;
      }
      output_types->emplace_back(tiled_logical_computation_type);
    } else if (output_sharding_type == xla::OpSharding::REPLICATED ||
               IsAssignedToLogicalDevice(core_id, output_sharding)) {
      output_types->emplace_back(cluster_func_output_type);
    }
  }

  return mlir::success();
}

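// Remaps the users of each `tf_device.cluster_func` result to the
// corresponding output(s) of the `tf_device.parallel_execute` op, inserting
// concat trees for tile-sharded outputs.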
mlir::LogicalResult RemapOutputsFromLogicalDevices(
    const mlir::Location& location,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::tf_device::ParallelExecuteOp parallel_execute,
    mlir::OpBuilder* builder) {
  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto cluster_func_output = result_and_index.value();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();

    // If the output is demultiplexed using the `tf.TPUPartitionedOutput` op,
    // only replicated and tiled sharding are supported. For replicated
    // sharding, the i-th output of the `tf.TPUPartitionedOutput` op maps to
    // the output of the i-th logical device. Also, the
    // `tf.TPUPartitionedOutput` op must be the unique user of the
    // `tf_device.cluster_func` output.
    mlir::TF::TPUPartitionedOutputOp partitioned_output;
    for (auto user : cluster_func_output.getUsers()) {
      if (auto partitioned_output_user =
              llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedOutputOp>(user)) {
        partitioned_output = partitioned_output_user;
        break;
      }
    }
    if (partitioned_output) {
      if (!cluster_func_output.hasOneUse())
        return partitioned_output.emitOpError()
               << "must be a unique user of tf_device.cluster_func output "
               << *cluster_func_output.getOwner();
      if (UnsupportedPartitionedShardingType(output_sharding_type))
        return cluster_func.emitOpError()
               << "unsupported output sharding type "
               << OpSharding_Type_Name(output_sharding_type) << " for "
               << output_index << "-th output";

      if (output_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto index_and_output :
             llvm::enumerate(partitioned_output.output())) {
          const auto output_from_logical_device =
              parallel_execute.GetRegionOutputs(
                  index_and_output.index())[output_index];
          index_and_output.value().replaceAllUsesWith(
              output_from_logical_device);
        }
      } else {
        assert(output_sharding_type == xla::OpSharding::OTHER);
        auto tile_sharded_outputs = GetTileShardedOutputsToMerge(
            output_index, output_sharding, parallel_execute);
        for (auto result :
             llvm::zip(partitioned_output.output(), tile_sharded_outputs))
          std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }
      continue;
    }

    if (output_sharding_type == xla::OpSharding::OTHER) {
      (void)HandleTileShardedOutputs(output_index, output_sharding, location,
                                     cluster_func_output, parallel_execute,
                                     builder);
      continue;
    }

    int logical_device_id = 0;
    if (output_sharding_type == xla::OpSharding::MAXIMAL)
      logical_device_id = output_sharding.tile_assignment_devices(0);

    // For replicated or maximal sharding, remap the output of the
    // parallel_execute region for the chosen logical device to the users of
    // the cluster func output.
    const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex(
        output_sharding_config, logical_device_id, output_index);

    const auto output_from_logical_device = parallel_execute.GetRegionOutputs(
        logical_device_id)[region_output_index];
    cluster_func_output.replaceAllUsesWith(output_from_logical_device);
  }
  return mlir::success();
}

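// Returns, for each logical core, the indices of the metadata args that are
// fed to that core, based on each arg's sharding. E.g., with two cores per
// replica and arg shardings [REPLICATED, MAXIMAL(1), OTHER assigned to
// devices {0, 1}], core 0 receives args [0, 2] and core 1 receives
// args [0, 1, 2].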
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> GetMetadataArgumentMapping(
    const tpu::TPUCompileMetadataProto& metadata) {
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings(
      metadata.num_cores_per_replica(), llvm::SmallVector<int64_t, 4>());

  if (metadata.num_cores_per_replica() == 1) {
    input_mappings.front().resize(metadata.args_size());
    std::iota(input_mappings.front().begin(), input_mappings.front().end(), 0);
    return input_mappings;
  }

  for (const auto& arg_and_idx : llvm::enumerate(metadata.args())) {
    const auto& sharding = arg_and_idx.value().sharding();
    const int64_t idx = arg_and_idx.index();

    const auto sharding_type = sharding.type();
    if (sharding_type == xla::OpSharding::OTHER) {
      for (const auto& device : sharding.tile_assignment_devices())
        input_mappings[device].push_back(idx);
    } else if (sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& input : input_mappings) input.push_back(idx);
    } else {
      assert(sharding_type == xla::OpSharding::MAXIMAL);
      input_mappings[sharding.tile_assignment_devices(0)].push_back(idx);
    }
  }

  return input_mappings;
}

}  // namespace tensorflow
