/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"

#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace tensorflow {
namespace {

constexpr char kNumSplitAttr[] = "num_split";

// Creates a tf::SplitOp that splits 'src_input' into 'num_split' ways
// along the 'split_dimension' dimension, and returns the created split op
// through 'split_op'.
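// For example (a sketch, assuming a statically shaped input): splitting an
// 8x16 tensor 4 ways along dimension 0 yields four 2x16 results. If the split
// dimension is dynamic or the input is unranked, the results reuse the input
// type as-is.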
mlir::LogicalResult CreateSplitOp(const int num_split,
                                  const int split_dimension,
                                  const mlir::Location& location,
                                  mlir::Value src_input,
                                  mlir::OpBuilder* builder,
                                  mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly set the output shapes of the split op if the input shape is
  // statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return mlir::emitError(
            location,
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split));
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  (*split_op)->setAttr(
      kNumSplitAttr,
      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
  return mlir::success();
}

// Creates a tf::ConcatOp that merges the `inputs` values along
// `concat_dimension`.
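// This is the inverse of CreateSplitOp above: e.g. (a sketch) concatenating
// four 2x16 values along dimension 0 reproduces an 8x16 result, with the
// static output shape derived from the first input when its rank is known.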
mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension,
                                  const mlir::Location& location,
                                  mlir::ArrayRef<mlir::Value> inputs,
                                  mlir::OpBuilder* builder) {
  // Creates a const op to hold concat dimension value.
  auto concat_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto concat_dimension_attr =
      mlir::DenseElementsAttr::get(concat_dim_type, concat_dimension);
  auto concat_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, concat_dim_type, concat_dimension_attr);

  // Correctly set the output shape of the concat op if the input shape is
  // statically known. Since the output shape of the TPUExecute op must be the
  // same across logical devices, we refer to the shape of the 0th logical
  // device's computation output.
  mlir::Type output_type;
  auto input_type = inputs[0].getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[concat_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      shape[concat_dimension] = shape[concat_dimension] * inputs.size();
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  return builder->create<mlir::TF::ConcatOp>(
      location, output_type, concat_dimension_op.output(), inputs);
}

// For tile-sharded inputs to a TPU computation, injects split ops between the
// input values and the TPU computation so that tiled input values are passed
// in as inputs to the TPU computations. If more than one dimension is sharded,
// a tree of connected split ops is added before the tf_device.parallel_execute
// op.
mlir::LogicalResult HandleTileShardedInputs(
    const mlir::Location& location, const xla::OpSharding& input_sharding,
    const mlir::Value& original_source, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
  split_ops_for_tiled_input.reserve(
      input_sharding.tile_assignment_devices_size());

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that input data is sharded in row-major order.
  // Split nodes at the i-th depth from the original input node split the
  // input data along the i-th dimension.
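  // For example (a sketch), with tile_assignment_dimensions = [2, 2] the input
  // is first split in two along dimension 0, each half is then split in two
  // along dimension 1, and the four leaves are emitted in row-major tile order
  // (0,0), (0,1), (1,0), (1,1).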
  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
    const int num_splits = num_splits_and_index.value();
    const int dimension_index = num_splits_and_index.index();
    if (num_splits == 1) continue;

    // Creates root split op.
    if (split_ops_for_tiled_input.empty()) {
      mlir::TF::SplitOp root_split_op;
      auto result = CreateSplitOp(num_splits, dimension_index, location,
                                  original_source, builder, &root_split_op);
      if (mlir::failed(result)) return mlir::failure();

      split_ops_for_tiled_input.emplace_back(root_split_op);
      continue;
    }

    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);

    for (auto split_op : split_ops_for_tiled_input) {
      for (auto parent_split_output_value : split_op.getResults()) {
        mlir::TF::SplitOp child_split_op;
        auto result =
            CreateSplitOp(num_splits, dimension_index, location,
                          parent_split_output_value, builder, &child_split_op);
        if (mlir::failed(result)) return mlir::failure();

        new_split_ops.emplace_back(child_split_op);
      }
    }

    std::swap(new_split_ops, split_ops_for_tiled_input);
  }

  // `split_ops_for_tiled_input` now contains the final split nodes
  // from which sharded data will be fed into TPUExecute ops -- sorted in
  // row-major order.
  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
  for (auto split_op : split_ops_for_tiled_input)
    tiled_inputs->append(split_op.getResults().begin(),
                         split_op.getResults().end());

  return mlir::success();
}

bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) {
  return sharding != xla::OpSharding::REPLICATED &&
         sharding != xla::OpSharding::OTHER;
}

}  // namespace

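// Builds, for each of the `num_cores_per_replica` logical devices, the list of
// values that the corresponding program consumes, based on the cluster_func's
// input sharding configuration (replicated, maximal, or tiled).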
mlir::LogicalResult ExtractInputsForLogicalDevices(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
  // Initialize the input list for each logical device.
  input_list->reserve(num_cores_per_replica);
  for (int i = 0; i < num_cores_per_replica; ++i)
    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());

  llvm::SmallVector<mlir::Value, 4> cluster_func_inputs(
      cluster_func.getOperands());
  auto sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kInputShardingAttr);
  // If the sharding attribute does not exist, then all inputs are placed on
  // the 0th logical core by default.
  if (!sharding_attrs) {
    (*input_list)[0] = cluster_func_inputs;
    return mlir::success();
  }

  // Enumerate the sharding configuration for each input. If an input has
  // replicated sharding then all logical devices take the value as input. If
  // an input has maximal sharding then only the specified logical device takes
  // the value as input.
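  // For example, with num_cores_per_replica = 2 a replicated input is appended
  // to both (*input_list)[0] and (*input_list)[1], while a maximal input
  // assigned to device 1 is appended only to (*input_list)[1]. Tiled (OTHER)
  // sharding is handled by splitting the input across the assigned devices.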
  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
    const auto& sharding_attr = sharding_attr_and_index.value();
    const auto input_index = sharding_attr_and_index.index();
    const auto& input_value = cluster_func_inputs[input_index];

    xla::OpSharding sharding;
    sharding.ParseFromString(
        sharding_attr.cast<mlir::StringAttr>().getValue().str());

    const auto input_sharding_type = sharding.type();

    auto tiled_sharding_mismatched = [&](int tiled_input_size) {
      return cluster_func.emitError(
          llvm::formatv("incorrect {0}-th tiled input sharding received. "
                        "Product of tile sharding splits({1}) must be equal to "
                        "number of logical devices : {2}",
                        input_index, tiled_input_size, num_cores_per_replica));
    };

    // If the input is already partitioned using the `tf.TPUPartitionedInput`
    // op, only replicated and tiled (OTHER) sharding are supported. For
    // replicated sharding, the i-th operand of the `tf.TPUPartitionedInput`
    // op is the input to the i-th logical device.
    if (auto partitioned_input =
            llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedInputOp>(
                input_value.getDefiningOp())) {
      if (UnsupportedPartitionedShardingType(input_sharding_type))
        return cluster_func->emitOpError()
               << "unsupported input sharding type "
               << OpSharding_Type_Name(input_sharding_type) << " for "
               << input_index << "-th input";

      if (input_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto& index_and_inputs : llvm::enumerate(*input_list)) {
          index_and_inputs.value().emplace_back(
              partitioned_input.getOperand(index_and_inputs.index()));
        }
      } else {
        assert(input_sharding_type == xla::OpSharding::OTHER);
        if (partitioned_input.inputs().size() != num_cores_per_replica)
          return tiled_sharding_mismatched(partitioned_input.inputs().size());

        for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
          const int assigned_logical_device =
              sharding.tile_assignment_devices(i);
          (*input_list)[assigned_logical_device].emplace_back(
              partitioned_input.inputs()[i]);
        }
      }
      continue;
    }

    if (input_sharding_type == xla::OpSharding::OTHER) {
      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
      auto result = HandleTileShardedInputs(
          cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
      if (mlir::failed(result)) return mlir::failure();

      const int64_t tiled_inputs_size = tiled_inputs.size();
      if (tiled_inputs_size != num_cores_per_replica)
        return tiled_sharding_mismatched(tiled_inputs.size());

      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
        const int assigned_logical_device = sharding.tile_assignment_devices(i);
        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
      }
    } else if (input_sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
    } else {
      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
      const int logical_device_id = sharding.tile_assignment_devices(0);
      (*input_list)[logical_device_id].emplace_back(input_value);
    }
  }
  return mlir::success();
}

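// Parses the output sharding configuration attached to `cluster_func` and
// validates each entry against `num_cores_per_replica`, appending the parsed
// xla::OpSharding protos to `output_sharding_list`.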
mlir::LogicalResult ParseAndValidateOutputSharding(
    const int num_cores_per_replica,
    mlir::tf_device::ClusterFuncOp cluster_func,
    mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list) {
  output_sharding_list->reserve(cluster_func.getNumResults());

  const auto output_sharding_attrs =
      cluster_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
          kOutputShardingAttr);
  if (!output_sharding_attrs)
    return cluster_func.emitError(
        "output_sharding_configuration missing from cluster func");

  if (output_sharding_attrs.size() != cluster_func.getNumResults())
    return cluster_func.emitError("incorrect number of output sharding");

  for (auto output_sharding_and_index :
       llvm::enumerate(output_sharding_attrs)) {
    const auto& output_sharding = output_sharding_and_index.value();
    const int sharding_index = output_sharding_and_index.index();
    if (!output_sharding.isa<mlir::StringAttr>())
      return cluster_func.emitError(llvm::formatv(
          "non-string output sharding at index {0}", sharding_index));

    xla::OpSharding sharding;
    if (!sharding.ParseFromString(
            output_sharding.cast<mlir::StringAttr>().getValue().str()))
      return cluster_func.emitError("incorrect sharding format for outputs");

    if (sharding.type() == xla::OpSharding::OTHER &&
        sharding.tile_assignment_devices_size() != num_cores_per_replica)
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Number of "
          "tiled outputs({0}) must match the number of logical "
          "devices({1})",
          sharding.tile_assignment_devices_size(), num_cores_per_replica));

    if (sharding.type() == xla::OpSharding::MAXIMAL &&
        ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) ||
         (sharding.tile_assignment_devices(0) < 0)))
      return cluster_func.emitError(llvm::formatv(
          "incorrect sharding format for outputs. Maximal "
          "sharding should be assigned to device id in range "
          "[0, {0}). Currently assigned to {1}",
          num_cores_per_replica, sharding.tile_assignment_devices(0)));

    output_sharding_list->emplace_back(std::move(sharding));
  }
  return mlir::success();
}

namespace {

bool IsAssignedToLogicalDevice(const int core_id,
                               const xla::OpSharding& sharding) {
  return sharding.type() == xla::OpSharding::MAXIMAL &&
         sharding.tile_assignment_devices(0) == core_id;
}

// Returns (through `core_output_index`) the index of the return value of the
// `tf_device.parallel_execute` region that represents the cluster func output
// at index |cluster_func_output_index|. Regions of parallel_execute may have
// different return values depending on the output sharding configuration.
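// For example, if core 1 only materializes cluster func outputs {0, 2}, then
// cluster_to_core_index[1] would be {0, -1, 1}: output 1 is not produced on
// that core, so looking it up is an error.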
mlir::LogicalResult LookupClusterToCoreIndex(
    const mlir::Location& location,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    const int core_id, const int cluster_func_output_index,
    int* core_output_index) {
  *core_output_index =
      cluster_to_core_index[core_id][cluster_func_output_index];
  if (*core_output_index == -1) {
    mlir::emitError(
        location,
        llvm::formatv("Attempted to map cluster_func output index {0} to "
                      "program assigned to core {1}. The tensor at this output "
                      "index was not assigned or sharded to this core.",
                      cluster_func_output_index, core_id));
    return mlir::failure();
  }
  return mlir::success();
}

// Collects the tile-sharded outputs from a tf_device.parallel_execute that
// correspond to a single cluster func result, so they can be merged back into
// the shape of the original TPU computation result.
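// The values are gathered per logical device, following the order of the
// sharding's tile_assignment_devices, so that the caller can concatenate them
// back in row-major tile order.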
mlir::LogicalResult GetTileShardedOutputsToMerge(
    const mlir::Location& location, const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    int cluster_idx, mlir::tf_device::ParallelExecuteOp new_parallel_execute,
    llvm::SmallVector<mlir::Value, 4>* outputs_to_merge) {
  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  outputs_to_merge->reserve(sharding.tile_assignment_devices_size());
  for (const auto logical_device_id : sharding.tile_assignment_devices()) {
    int region_output_index;
    auto status = LookupClusterToCoreIndex(
        location, cluster_to_core_index, logical_device_id,
        cluster_func_output_index, &region_output_index);
    if (failed(status)) return mlir::failure();
    const auto output_from_logical_device =
        new_parallel_execute.GetRegionOutputs(
            cluster_idx + logical_device_id)[region_output_index];
    outputs_to_merge->emplace_back(output_from_logical_device);
  }

  return mlir::success();
}

// Merges outputs from TPU computation for tile-sharded outputs.
mlir::LogicalResult HandleTileShardedOutputs(
    const int cluster_func_output_index,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    const mlir::Location& location, mlir::Value cluster_func_output,
    int cluster_idx, mlir::tf_device::ParallelExecuteOp new_parallel_execute,
    mlir::OpBuilder* builder) {
  // Inject concat ops after parallel_execute to merge outputs from
  // concurrently executed computations.
  builder->setInsertionPointAfter(new_parallel_execute);

  // Reorders outputs from TPUExecute op as defined by the output sharding
  // configuration.
  llvm::SmallVector<mlir::Value, 4> outputs_to_merge;
  auto status = GetTileShardedOutputsToMerge(
      location, cluster_func_output_index, output_sharding_config,
      cluster_to_core_index, cluster_idx, new_parallel_execute,
      &outputs_to_merge);
  if (failed(status)) return mlir::failure();

  // Creates a tree of Concat ops that merges outputs from multiple logical
  // devices into a single replica output.
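  // For example, with tile_assignment_dimensions = [2, 2] the four
  // row-major-ordered values are first concatenated pairwise along dimension 1
  // and the two intermediate results are then concatenated along dimension 0,
  // mirroring the split tree built in HandleTileShardedInputs.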
  const xla::OpSharding& sharding =
      output_sharding_config[cluster_func_output_index];
  int concat_dimension = sharding.tile_assignment_dimensions_size() - 1;
  for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) {
    if (num_splits == 1) {
      --concat_dimension;
      continue;
    }

    llvm::SmallVector<mlir::Value, 4> new_outputs;
    new_outputs.reserve(num_splits);
    for (int i = 0, end = outputs_to_merge.size(); i < end;
         i = i + num_splits) {
      mlir::TF::ConcatOp concat_op =
          CreateConcatOp(concat_dimension, location,
                         llvm::ArrayRef<mlir::Value>{
                             outputs_to_merge.begin() + i,
                             outputs_to_merge.begin() + i + num_splits},
                         builder);
      new_outputs.emplace_back(concat_op.getResult());
    }

    std::swap(new_outputs, outputs_to_merge);
    --concat_dimension;
  }

  assert(outputs_to_merge.size() == 1);
  cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]);
  return mlir::success();
}

mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape(
    const mlir::Location& location,
    const mlir::TensorType cluster_func_output_type,
    const xla::OpSharding& output_sharding,
    mlir::Type* tiled_logical_computation_type) {
  auto new_output_shape =
      llvm::to_vector<4>(cluster_func_output_type.getShape());
  for (auto dimension_and_output_splits :
       llvm::enumerate(output_sharding.tile_assignment_dimensions())) {
    const auto dimension_index = dimension_and_output_splits.index();
    const auto output_splits = dimension_and_output_splits.value();
    const auto output_shape = cluster_func_output_type.getShape();

    if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) {
      *tiled_logical_computation_type = cluster_func_output_type;
      break;
    }

    auto output_shape_at_dim =
        cluster_func_output_type.getShape()[dimension_index];
    if (output_shape_at_dim % output_splits != 0) {
      mlir::emitError(
          location,
          llvm::formatv("incorrect output sharding received. "
                        "{0}-th dimension of the output must be "
                        "evenly divisible by {1}, got dimension "
                        "shape {2}",
                        dimension_index, output_splits, output_shape_at_dim));
      return mlir::failure();
    }

    new_output_shape[dimension_index] =
        output_shape[dimension_index] / output_splits;
  }

  *tiled_logical_computation_type = mlir::RankedTensorType::get(
      new_output_shape, cluster_func_output_type.getElementType());

  return mlir::success();
}

}  // namespace

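// Determines the result types of the program assigned to `core_id` from the
// cluster func results and their sharding: tiled results get the per-core
// (split) type, replicated or maximal-on-this-core results keep the original
// type, and results not present on this core record -1 in
// `cluster_to_core_index`.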
mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    mlir::tf_device::ClusterFuncOp cluster_func,
    llvm::SmallVectorImpl<mlir::Type>* output_types,
    llvm::SmallVectorImpl<int>* cluster_to_core_index) {
  output_types->reserve(cluster_func.getNumResults());

  int core_index = 0;
  for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) {
    const auto output_index = result_and_index.index();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();
    const auto cluster_func_output_type =
        result_and_index.value().getType().cast<mlir::TensorType>();

    // If the output shape of the cluster func is statically known and the
    // output is tile-sharded, then the corresponding output dimension of the
    // cluster func must be evenly divisible by the number of shardings.
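    // For example, an 8x16 output tiled 2 ways along dimension 0 yields a
    // per-core output type of 4x16, while an 8x15 output tiled 2 ways along
    // dimension 1 would be rejected by ValidateAndGetTiledExecuteOutputShape.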
    if (output_sharding_type == xla::OpSharding::OTHER) {
      mlir::Type tiled_logical_computation_type;
      if (cluster_func_output_type.hasRank()) {
        auto result = ValidateAndGetTiledExecuteOutputShape(
            cluster_func.getLoc(), cluster_func_output_type, output_sharding,
            &tiled_logical_computation_type);
        if (mlir::failed(result)) return mlir::failure();
      } else {
        tiled_logical_computation_type = cluster_func_output_type;
      }
      cluster_to_core_index->emplace_back(core_index++);
      output_types->emplace_back(tiled_logical_computation_type);
    } else if (output_sharding_type == xla::OpSharding::REPLICATED ||
               IsAssignedToLogicalDevice(core_id, output_sharding)) {
      cluster_to_core_index->emplace_back(core_index++);
      output_types->emplace_back(cluster_func_output_type);
    } else {
      cluster_to_core_index->emplace_back(-1);
    }
  }

  return mlir::success();
}

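// Remaps each result of `old_parallel_execute` to the corresponding result(s)
// of `new_parallel_execute`, honoring the output sharding configuration:
// tile-sharded outputs are reassembled with concat ops, replicated outputs are
// read from logical device 0, and maximal outputs are read from the assigned
// logical device. Results consumed by a `tf.TPUPartitionedOutput` op are
// instead forwarded per logical device.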
mlir::LogicalResult RemapOutputsFromLogicalDevices(
    const mlir::Location& location,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index,
    mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx,
    mlir::tf_device::ParallelExecuteOp new_parallel_execute,
    mlir::OpBuilder* builder) {
  for (auto& result_and_index :
       llvm::enumerate(old_parallel_execute.getResults())) {
    const auto output_index = result_and_index.index();
    const auto old_parallel_execute_output = result_and_index.value();
    const auto& output_sharding = output_sharding_config[output_index];
    const auto output_sharding_type = output_sharding.type();

    // If the output is demultiplexed using the `tf.TPUPartitionedOutput` op,
    // only replicated and tiled (OTHER) sharding are supported. For replicated
    // sharding, the i-th output of the `tf.TPUPartitionedOutput` op maps to
    // the output of the i-th logical device. In addition, the
    // `tf.TPUPartitionedOutput` op must be the unique user of the TPU cluster
    // (`tf_device.parallel_execute`) output.
    mlir::TF::TPUPartitionedOutputOp partitioned_output;
    for (auto user : old_parallel_execute_output.getUsers()) {
      if (auto partitioned_output_user =
              llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedOutputOp>(user)) {
        partitioned_output = partitioned_output_user;
        break;
      }
    }
    if (partitioned_output) {
      if (!old_parallel_execute_output.hasOneUse())
        return partitioned_output.emitOpError()
               << "must be a unique user of TPU Cluster "
                  "(tf_device.parallel_execute) output "
               << *old_parallel_execute_output.getOwner();
      if (UnsupportedPartitionedShardingType(output_sharding_type))
        return old_parallel_execute.emitOpError()
               << "unsupported output sharding type "
               << OpSharding_Type_Name(output_sharding_type) << " for "
               << output_index << "-th output";

      if (output_sharding_type == xla::OpSharding::REPLICATED) {
        for (auto index_and_output :
             llvm::enumerate(partitioned_output.output())) {
          const auto output_from_logical_device =
              new_parallel_execute.GetRegionOutputs(
                  cluster_idx + index_and_output.index())[output_index];
          index_and_output.value().replaceAllUsesWith(
              output_from_logical_device);
        }
      } else {
        assert(output_sharding_type == xla::OpSharding::OTHER);
        llvm::SmallVector<mlir::Value, 4> tile_sharded_outputs;
        if (failed(GetTileShardedOutputsToMerge(
                location, output_index, output_sharding_config,
                cluster_to_core_index, cluster_idx, new_parallel_execute,
                &tile_sharded_outputs)))
          return mlir::failure();
        for (auto result :
             llvm::zip(partitioned_output.output(), tile_sharded_outputs))
          std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }
      continue;
    }

    if (output_sharding_type == xla::OpSharding::OTHER) {
      if (failed(HandleTileShardedOutputs(
              output_index, output_sharding_config, cluster_to_core_index,
              location, old_parallel_execute_output, cluster_idx,
              new_parallel_execute, builder)))
        return mlir::failure();
      continue;
    }

    int logical_device_id = 0;
    if (output_sharding_type == xla::OpSharding::MAXIMAL)
      logical_device_id = output_sharding.tile_assignment_devices(0);

    // For replicated and maximal sharding configurations, remap the
    // parallel_execute region output to the users of the cluster func output.
    int region_output_index;
    if (failed(LookupClusterToCoreIndex(location, cluster_to_core_index,
                                        logical_device_id, output_index,
                                        &region_output_index)))
      return mlir::failure();

    const auto output_from_logical_device =
        new_parallel_execute.GetRegionOutputs(
            cluster_idx + logical_device_id)[region_output_index];
    old_parallel_execute_output.replaceAllUsesWith(output_from_logical_device);
  }
  return mlir::success();
}

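// Maps each argument in `metadata` to the logical devices that consume it:
// tiled (OTHER) arguments go to their tile_assignment_devices, replicated
// arguments go to every core, and maximal arguments go to the single assigned
// core. With one core per replica, every argument maps to core 0.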
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> GetMetadataArgumentMapping(
    const tpu::TPUCompileMetadataProto& metadata) {
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings(
      metadata.num_cores_per_replica(), llvm::SmallVector<int64_t, 4>());

  if (metadata.num_cores_per_replica() == 1) {
    input_mappings.front().resize(metadata.args_size());
    std::iota(input_mappings.front().begin(), input_mappings.front().end(), 0);
    return input_mappings;
  }

  for (const auto& arg_and_idx : llvm::enumerate(metadata.args())) {
    const auto& sharding = arg_and_idx.value().sharding();
    const int64_t idx = arg_and_idx.index();

    const auto sharding_type = sharding.type();
    if (sharding_type == xla::OpSharding::OTHER) {
      for (const auto& device : sharding.tile_assignment_devices())
        input_mappings[device].push_back(idx);
    } else if (sharding_type == xla::OpSharding::REPLICATED) {
      for (auto& input : input_mappings) input.push_back(idx);
    } else {
      assert(sharding_type == xla::OpSharding::MAXIMAL);
      input_mappings[sharding.tile_assignment_devices(0)].push_back(idx);
    }
  }

  return input_mappings;
}

}  // namespace tensorflow