/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include <algorithm>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

using absl::StrCat;
using absl::StrJoin;

HloSharding HloSharding::AssignDevice(int64_t device_id,
                                      absl::Span<const OpMetadata> metadata) {
  return HloSharding(device_id, metadata);
}

HloSharding HloSharding::Tile1D(const Shape& input_shape, int64_t num_tiles,
                                absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(1, input_shape.rank());
  CHECK_GT(num_tiles, 1);
  std::vector<int64> dimensions(1, num_tiles);
  Array<int64> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(assignment, /*replicate_on_last_tile_dim=*/false,
                     metadata);
}

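// Expands a per-group tile assignment into a device-level tile assignment
// with a trailing replication dimension: tile (i, ..., k) is assigned the
// k-th device of the replication group selected by (i, ...). As an
// illustrative example, a 2x1 group assignment {{0}, {1}} with groups
// {{0, 1}, {2, 3}} expands to the 2x1x2 assignment {{{0, 1}}, {{2, 3}}},
// which is then canonicalized by the PartialTile() overload below.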
HloSharding HloSharding::PartialTile(
    const Array<int64>& group_tile_assignment,
    absl::Span<const absl::Span<const int64>> replication_groups,
    absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(group_tile_assignment.num_elements(), replication_groups.size());
  if (replication_groups.size() == 1) {
    return Replicate(metadata);
  }
  auto new_tile_dims = group_tile_assignment.dimensions();
  new_tile_dims.push_back(replication_groups[0].size());
  auto new_tile_assignment = Array<int64>(new_tile_dims);
  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
    std::vector<int64> group_index(indices.begin(), indices.end());
    group_index.pop_back();
    int64_t group = group_tile_assignment(group_index);
    *device = replication_groups[group][indices.back()];
  });
  return PartialTile(new_tile_assignment, metadata);
}

HloSharding HloSharding::PartialTile(
    const Array<int64>& tile_assignment_last_dim_replicate,
    absl::Span<const OpMetadata> metadata) {
  if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
      tile_assignment_last_dim_replicate.dimensions().back() ==
          tile_assignment_last_dim_replicate.num_elements()) {
    return Replicate(metadata);
  }
  if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
    auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
    new_tile_dims.pop_back();
    auto fully_tiled = tile_assignment_last_dim_replicate;
    fully_tiled.Reshape(new_tile_dims);
    return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
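  // Canonicalize the replicated (trailing) dimension: collect the devices of
  // each replication group into a sorted set and write them back in order, so
  // that logically identical partial shardings get identical tile assignments.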
  std::vector<std::set<int64>> sorted_groups(
      tile_assignment_last_dim_replicate.num_elements() /
      tile_assignment_last_dim_replicate.dimensions().back());
  auto get_group_id = [&](absl::Span<const int64> indices) {
    int64_t group_id = 0;
    for (int64_t i = 0; i < indices.size() - 1; ++i) {
      group_id *= tile_assignment_last_dim_replicate.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment_last_dim_replicate.Each(
      [&](absl::Span<const int64> indices, const int64_t device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64> sorted_tile(tile_assignment_last_dim_replicate.dimensions());
  sorted_tile.Each([&](absl::Span<const int64> indices, int64* device) {
    const int64_t group_id = get_group_id(indices);
    auto begin = sorted_groups[group_id].begin();
    *device = *begin;
    sorted_groups[group_id].erase(begin);
  });
  return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true,
                     metadata);
}

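// Builds a sharding whose last sharding_types.size() tile dimensions describe
// subgroups (e.g. replicated or manual) rather than data tiling. As in
// PartialTile() above, the devices within each subgroup are rewritten in
// sorted order to produce a canonical assignment.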
HloSharding HloSharding::Subgroup(
    const Array<int64>& tile_assignment,
    absl::Span<const OpSharding::Type> sharding_types,
    absl::Span<const OpMetadata> metadata) {
  if (sharding_types.empty()) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }

  CHECK_GT(tile_assignment.dimensions().size(), sharding_types.size());
  int64 num_group_dims = sharding_types.size();
  std::vector<int64> group_dimensions(
      tile_assignment.dimensions().end() - num_group_dims,
      tile_assignment.dimensions().end());
  int64 num_groups = Product(group_dimensions);
  std::vector<std::set<int64>> sorted_groups(num_groups);
  auto get_group_id = [&](absl::Span<const int64> indices) {
    int64_t group_id = 0;
    for (int64_t i = 0; i < indices.size() - num_group_dims; ++i) {
      group_id *= tile_assignment.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment.Each(
      [&](absl::Span<const int64> indices, const int64_t device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64> sorted_tile(tile_assignment.dimensions());
  sorted_tile.Each([&](absl::Span<const int64> indices, int64* device) {
    const int64_t group_id = get_group_id(indices);
    auto begin = sorted_groups[group_id].begin();
    *device = *begin;
    sorted_groups[group_id].erase(begin);
  });

  return HloSharding(sorted_tile, sharding_types, metadata);
}

HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
  std::vector<HloSharding> flattened_list;
  flattened_list.reserve(sub_shardings.leaf_count());
  for (const auto& index_to_sharding : sub_shardings.leaves()) {
    flattened_list.push_back(index_to_sharding.second);
  }
  if (flattened_list.empty()) {
    // Empty tuple sharding ends up having no leaves, but we want to allow
    // empty tuple HLO instruction results to have sharding, so we fetch the
    // root ({}) sharding value from the ShapeTree.
    // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
    // init as value at its root.
    flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
  }
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Tuple(const Shape& tuple_shape,
                               absl::Span<const HloSharding> shardings) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  for (auto& sharding : shardings) {
    CHECK(!sharding.IsTuple())
        << sharding.ToString() << ShapeUtil::HumanString(tuple_shape);
  }
  std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
  CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
      << "Flat list has " << flattened_list.size() << ", required "
      << RequiredLeaves(tuple_shape);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();
  int64_t leaf_count = RequiredLeaves(tuple_shape);
  std::vector<HloSharding> flattened_list;
  flattened_list.resize(leaf_count, sharding);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding;
}

string HloSharding::ToString(bool include_metadata) const {
  if (IsTuple()) {
    CHECK(metadata_.empty());
    std::string result = "{";
    for (int i = 0; i < tuple_elements_.size(); ++i) {
      const HloSharding& element = tuple_elements_[i];
      if (i != 0) {
        absl::StrAppend(&result, ", ");
        if (i % 5 == 0) {
          absl::StrAppend(&result, "/*index=", i, "*/");
        }
      }
      absl::StrAppend(&result, element.ToString(include_metadata));
    }
    absl::StrAppend(&result, "}");
    return result;
  }

  std::string metadata;
  if (include_metadata) {
    if (metadata_.size() == 1) {
      metadata =
          StrCat(" metadata={", OpMetadataToString(metadata_.front()), "}");
    } else if (metadata_.size() > 1) {
      std::vector<std::string> metadata_strings;
      metadata_strings.reserve(metadata_.size());
      for (const auto& single_metadata : metadata_) {
        metadata_strings.push_back(
            StrCat("{", OpMetadataToString(single_metadata), "}"));
      }
      metadata = StrCat(" metadata={", StrJoin(metadata_strings, ", "), "}");
    }
  }

  std::string last_tile_dims;
  if (!sharding_types_.empty()) {
    auto op_sharding_type_to_string = [](OpSharding::Type type) {
      switch (type) {
        case OpSharding::MANUAL:
          return "manual";
        case OpSharding::MAXIMAL:
          return "maximal";
        case OpSharding::REPLICATED:
          return "replicated";
        default:
          return "error_type.";
      }
    };
    std::vector<std::string> sharding_type_strings;
    sharding_type_strings.reserve(sharding_types_.size());
    for (const auto& single_sharding_type : sharding_types_) {
      sharding_type_strings.push_back(
          op_sharding_type_to_string(single_sharding_type));
    }
    last_tile_dims =
        StrCat(" last_tile_dims={", StrJoin(sharding_type_strings, ", "), "}");
  }

  if (replicated_) {
    return StrCat("{replicated", metadata, "}");
  }

  if (manual_) {
    return StrCat("{manual", metadata, "}");
  }
  if (maximal_) {
    return StrCat("{maximal device=",
                  static_cast<int64>(*tile_assignment_.begin()), metadata, "}");
  }
  return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
                StrJoin(tile_assignment_, ","),
                replicate_on_last_tile_dim_ ? " last_tile_dim_replicate" : "",
                last_tile_dims, metadata, "}");
}

bool HloSharding::UsesDevice(int64_t device) const {
  if (IsTuple()) {
    return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
      return s.UsesDevice(device);
    });
  }
  const auto& devices = tile_assignment_;
  return replicated_ || manual_ || absl::c_linear_search(devices, device);
}

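// Maps every device that some element of this sharding is maximally assigned
// to, to the number of elements assigned to it; tiled and replicated elements
// contribute nothing. If non-null, *count receives the number of elements
// examined (the tuple leaf count, or 1 for a non-tuple sharding).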
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
  int64_t element_count = 1;
  std::map<int64, int64> device_map;
  if (IsTuple()) {
    for (auto& tuple_element_sharding : tuple_elements()) {
      auto unique_device = tuple_element_sharding.UniqueDevice();
      if (unique_device) {
        device_map[*unique_device] += 1;
      }
    }
    element_count = tuple_elements().size();
  } else {
    auto unique_device = UniqueDevice();
    if (unique_device) {
      device_map[*unique_device] += 1;
    }
  }
  if (count != nullptr) {
    *count = element_count;
  }
  return device_map;
}

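// Returns the tile index at which `device` appears in the tile assignment.
// When the trailing dimension encodes replication
// (replicate_on_last_tile_dim_), that dimension is dropped from the result.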
std::vector<int64> HloSharding::TileIndexForDevice(int64_t device) const {
  CHECK(!maximal_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  std::vector<int64> ret_index;
  tile_assignment_.Each([&](absl::Span<const int64> index, int64_t d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  if (replicate_on_last_tile_dim_) {
    ret_index.pop_back();
  }
  return ret_index;
}

int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
  CHECK(!replicated_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  if (replicate_on_last_tile_dim_ &&
      index.size() < tile_assignment().num_dimensions()) {
    std::vector<int64> first_replicated_index(index.begin(), index.end());
    first_replicated_index.push_back(0);
    return tile_assignment_(first_replicated_index);
  }
  return tile_assignment_(index);
}

std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
                                                    int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions_size(), 0);
  }
  if (replicate_on_last_tile_dim_) {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions() - 1);
  } else {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
  }
  std::vector<int64> index = TileIndexForDevice(device);
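  // The offset along dimension i is tile_index * ceil(dim / num_tiles),
  // clamped to the dimension size. For example (illustrative), a dimension of
  // size 10 split across 3 tiles has per-tile extent 4 and offsets 0, 4 and 8.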
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
  }
  return index;
}

std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
                                                   int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions().begin(),
                              shape.dimensions().end());
  }

  CHECK_EQ(shape.dimensions_size() + (ReplicateOnLastTileDim() ? 1 : 0),
           tile_assignment_.num_dimensions());
  std::vector<int64> index = TileIndexForDevice(device);
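  // The limit along dimension i is the offset of the next tile,
  // (tile_index + 1) * ceil(dim / num_tiles), clamped to the dimension size,
  // so the trailing tile may be smaller than the others.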
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
  }
  return index;
}

int64 HloSharding::RequiredLeaves(const Shape& shape) {
  // Empty tuples (with arbitrary nesting) have no leaf nodes as far as
  // ShapeUtil and ShapeTree are concerned, but they do have a single
  // tuple_elements_ entry since we want to allow empty tuple results to
  // have sharding.
  const int64_t leaf_count = ShapeUtil::GetLeafCount(shape);
  return (leaf_count == 0) ? 1 : leaf_count;
}

Status HloSharding::CheckLeafCount(const Shape& shape) const {
  int64_t shape_leaves = RequiredLeaves(shape);
  TF_RET_CHECK(shape_leaves == tuple_elements_.size())
      << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
      << " leaf nodes while this sharding has " << tuple_elements_.size();
  return Status::OK();
}

StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
    const Shape& shape) const {
  if (IsTuple()) {
    ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    auto it = tuple_elements_.begin();
    for (auto& index_to_sharding : result.leaves()) {
      index_to_sharding.second = *it++;
    }
    if (ShapeUtil::IsEmptyTuple(shape)) {
      // Empty tuples have no leaves, but we want to assign them a sharding
      // anyway, so we use the root element sharding.
      *result.mutable_element(ShapeIndex({})) = *it;
    }
    return std::move(result);
  } else {
    return ShapeTree<HloSharding>(shape, *this);
  }
}

StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
  if (IsTuple()) {
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    return *this;
  }
  return Tuple(ShapeTree<HloSharding>(shape, *this));
}

absl::optional<int64> HloSharding::UniqueDevice() const {
  if (IsTuple()) {
    if (tuple_elements_.empty()) {
      return absl::nullopt;
    }
    absl::optional<int64> unique_device;
    for (auto& tuple_sharding : tuple_elements_) {
      auto device = tuple_sharding.UniqueDevice();
      if (!device || (unique_device && *device != *unique_device)) {
        return absl::nullopt;
      }
      unique_device = device;
    }
    return unique_device;
  }
  if (!replicated_ && maximal_) {
    return static_cast<int64>(*tile_assignment_.begin());
  }
  return absl::nullopt;
}

int64 HloSharding::GetUniqueDevice() const {
  auto device = UniqueDevice();
  CHECK(device) << "Sharding does not have a unique device: " << *this;
  return *device;
}

Status HloSharding::ValidateTuple(const Shape& shape,
                                  int64_t num_devices) const {
  if (!shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Sharding is tuple-shaped but validation shape is not."));
  }
  TF_RETURN_IF_ERROR(CheckLeafCount(shape));

  // Now we've validated the number of tuple elements, it's safe to request a
  // shape tree.
  ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
  for (const auto& index_to_sharding : shape_tree.leaves()) {
    Status status = index_to_sharding.second.ValidateNonTuple(
        ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
    if (!status.ok()) {
      tensorflow::errors::AppendToMessage(
          &status, StrCat("Note: While validating sharding tuple element ",
                          index_to_sharding.first.ToString(), " which is ",
                          index_to_sharding.second.ToString()));
      return status;
    }
  }
  return Status::OK();
}

Status HloSharding::Validate(const Shape& shape, int64_t num_devices) const {
  if (shape.IsToken()) {
    return Status::OK();
  }
  Status status = IsTuple() ? ValidateTuple(shape, num_devices)
                            : ValidateNonTuple(shape, num_devices);
  if (!status.ok()) {
    tensorflow::errors::AppendToMessage(
        &status, StrCat("Note: While validating sharding ", ToString(),
                        " against shape ", ShapeUtil::HumanString(shape)));
  }
  return status;
}

Status HloSharding::ValidateNonTuple(const Shape& shape,
                                     int64_t num_devices) const {
  if (shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation shape is a tuple but sharding is not."));
  }
  if (replicated_) {
    return Status::OK();
  }

  // Each device in the tile assignment must be strictly less than the number
  // of available devices, and must not appear more than once.
  Status status = Status::OK();
  absl::flat_hash_set<int64> seen_cores;
  tile_assignment_.Each([&](absl::Span<const int64> indices, int32_t core) {
    // Don't overwrite a bad status, so we report the first error.
    if (status.ok()) {
      if (core >= num_devices) {
        status = tensorflow::errors::InvalidArgument(
513 StrCat("core ", core, " > ", num_devices, " in tile assignment"));
      } else if (seen_cores.contains(core)) {
        status = tensorflow::errors::InvalidArgument(
            StrCat("core ", core, " is not unique in tile assignment"));
      }
      seen_cores.insert(core);
    }
  });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal() || IsManual()) {
    return Status::OK();
  }

  // The tile assignment tensor must have the same rank as the input, or input
  // rank + 1 for replicate_on_last_tile_dim_.
  if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) !=
      tile_assignment_.num_dimensions()) {
    return tensorflow::errors::InvalidArgument(
534 "Number of tile assignment dimensions is different to the input rank. "
535 "sharding=",
536 ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
537 }
538
539 // The correct constructor has to be used to create tile maximal shardings.
540 if (tile_assignment_.num_elements() == 1) {
541 return tensorflow::errors::InvalidArgument(
542 "Tile assignment only contains a single device. If a replicated "
543 "sharding was intended, use HloSharding::Replicated(). If a device "
544 "placement was intended, use HloSharding::AssignDevice()");
545 }
546 return Status::OK();
547 }
548
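// Reconstructs an HloSharding from its OpSharding proto. Dispatch order:
// tuple, replicated, manual, single-device (maximal), then tiled, where a
// tiled proto with last_tile_dims or replicate_on_last_tile_dim set is routed
// through the subgroup constructor or PartialTile() respectively.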
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  std::vector<OpMetadata> metadata(proto.metadata().begin(),
                                   proto.metadata().end());
  std::vector<int> sharding_types_int(proto.last_tile_dims().begin(),
                                      proto.last_tile_dims().end());
  std::vector<OpSharding::Type> sharding_types;
  absl::c_transform(
      sharding_types_int, std::back_inserter(sharding_types),
      [](const int type) { return static_cast<OpSharding::Type>(type); });
  if (proto.type() == OpSharding::TUPLE) {
    TF_RET_CHECK(metadata.empty())
        << "Tuple sharding is expected to have no metadata.";
    std::vector<HloSharding> tuple_shardings;
    tuple_shardings.reserve(proto.tuple_shardings().size());
    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
      TF_ASSIGN_OR_RETURN(HloSharding sharding,
                          HloSharding::FromProto(tuple_sharding_proto));
      tuple_shardings.push_back(sharding);
    }
    return HloSharding(tuple_shardings);
  } else if (proto.type() == OpSharding::REPLICATED) {
    return Replicate(metadata);
  } else if (proto.type() == OpSharding::MANUAL) {
    return Manual(metadata);
  } else if (proto.tile_assignment_devices().size() == 1) {
    return HloSharding(proto.tile_assignment_devices(0), metadata);
  }

  TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
579 << "Maximal sharding is expected to have single device assignment, but "
580 << proto.tile_assignment_devices().size() << " has provided.";

  TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
  TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());

  // The product of the tile assignment tensor dimensions must equal
  // tile_assignment_devices.size().
  int64_t product_of_dimensions = 1;
  for (auto dimension : proto.tile_assignment_dimensions()) {
    TF_RET_CHECK(dimension > 0);
    product_of_dimensions =
        MultiplyWithoutOverflow(product_of_dimensions, dimension);
    TF_RET_CHECK(product_of_dimensions > 0);
  }
  TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());

  // Some versions of gcc cannot infer the TileAssignment constructor from a
  // braced initializer-list, so create one manually.
  std::vector<int64> devices(proto.tile_assignment_devices().begin(),
                             proto.tile_assignment_devices().end());
  Array<int64> tile_assignment(
      std::vector<int64>(proto.tile_assignment_dimensions().begin(),
                         proto.tile_assignment_dimensions().end()));
  std::copy(proto.tile_assignment_devices().begin(),
            proto.tile_assignment_devices().end(), tile_assignment.begin());
  if (!sharding_types.empty()) {
    TF_RET_CHECK(!proto.replicate_on_last_tile_dim());
    return HloSharding(tile_assignment, sharding_types, metadata);
  }
  return proto.replicate_on_last_tile_dim()
             ? PartialTile(tile_assignment, metadata)
             : HloSharding(tile_assignment,
                           /*replicate_on_last_tile_dim=*/false, metadata);
}

OpSharding HloSharding::ToProto() const {
  OpSharding result;

  if (IsTuple()) {
    CHECK(metadata_.empty());
    for (const HloSharding& element : tuple_elements_) {
      *result.add_tuple_shardings() = element.ToProto();
    }
    result.set_type(OpSharding::TUPLE);
    return result;
  }

  result.mutable_metadata()->Reserve(metadata_.size());
  for (const auto& metadata : metadata_) {
    *result.add_metadata() = metadata;
  }

  for (int64_t dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::REPLICATED);
    result.clear_tile_assignment_dimensions();
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::MAXIMAL);
  } else if (IsManual()) {
    result.set_type(OpSharding::MANUAL);
    result.clear_tile_assignment_dimensions();
  } else {
    result.set_type(OpSharding::OTHER);
    result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
  }
  return result;
}

Shape HloSharding::TileShape(const Shape& shape) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }
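  // Every tile gets the maximal per-dimension extent ceil(dim / num_tiles);
  // the device-specific overload below accounts for a smaller trailing tile.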
  Shape result_shape = shape;
  for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
    result_shape.set_dimensions(
        i, CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i)));
  }
  return result_shape;
}

Shape HloSharding::TileShape(const Shape& shape, int64_t device) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }

  std::vector<int64> index = TileIndexForDevice(device);
  Shape result_shape = shape;
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    int64_t offset = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
    int64_t limit = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
    result_shape.set_dimensions(i, limit - offset);
  }
  return result_shape;
}

int64 HloSharding::NumTiles() const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  if (ReplicateOnLastTileDim()) {
    return tile_assignment().num_elements() /
           tile_assignment().dimensions().back();
  }
  return tile_assignment().num_elements();
}

int64 HloSharding::NumTiles(absl::Span<const int64> dims) const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  CHECK(!ReplicateOnLastTileDim() ||
        !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));
  int64_t num_tiles = 1;
  for (auto d : dims) {
    CHECK(d < tile_assignment().num_dimensions());
    num_tiles *= tile_assignment().dim(d);
  }
  return num_tiles;
}

HloSharding HloSharding::GetSubSharding(const Shape& shape,
                                        const ShapeIndex& index) const {
  CHECK(IsTuple());
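  // tuple_elements_ stores the leaf shardings in the same depth-first order
  // as ShapeTree leaves, so the flat offset of `index` is the sum of the leaf
  // counts of every preceding sibling at each step of the walk.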
  int64_t sharding_index = 0;
  const Shape* sub_shape = &shape;
  for (int64_t idx : index) {
    for (int64_t i = 0; i < idx; ++i) {
      sharding_index +=
          ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
    }
    sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
  }
  if (sub_shape->IsTuple()) {
    auto begin_it = tuple_elements_.begin() + sharding_index;
    std::vector<HloSharding> sub_shardings(
        begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
    return HloSharding::Tuple(*sub_shape, sub_shardings);
  } else {
    return tuple_elements_[sharding_index];
  }
}

absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
  if (!IsTuple()) {
    return *this;
  }
  if (tuple_elements_.empty()) {
    return absl::nullopt;
  }
  for (int64_t i = 1; i < tuple_elements_.size(); ++i) {
    if (tuple_elements_[0] != tuple_elements_[i]) {
      return absl::nullopt;
    }
  }
  return tuple_elements_.front();
}

HloSharding HloSharding::WithMetadata(absl::Span<const OpMetadata> metadata,
                                      bool overwrite) const {
  auto assign_metadata = [&](HloSharding& sharding) {
    if (sharding.metadata_.empty() || overwrite) {
      sharding.metadata_.assign(metadata.begin(), metadata.end());
    }
  };

  HloSharding sharding = *this;
  if (sharding.IsTuple()) {
    for (HloSharding& sub_sharding : sharding.tuple_elements()) {
      assign_metadata(sub_sharding);
    }
  } else {
    assign_metadata(sharding);
  }
  return sharding;
}

HloSharding HloSharding::WithoutMetadata() const {
  HloSharding sharding = *this;
  sharding.metadata_.clear();
  for (HloSharding& sub_sharding : sharding.tuple_elements()) {
    sub_sharding.metadata_.clear();
  }
  return sharding;
}

size_t HloSharding::Hash() const {
  if (tuple_) {
    size_t h = 0;
    for (const auto& element : tuple_elements_) {
      h = tensorflow::Hash64Combine(h, element.Hash());
    }
    return h;
  }
  if (replicated_) {
    return 0;
  }
  if (manual_) {
    return 1;
  }
  size_t h = 0;
  for (uint32 v : tile_assignment_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
  }
  if (replicate_on_last_tile_dim_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(1));
  }
  return h;
}

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
  out << sharding.ToString();
  return out;
}

}  // namespace xla