/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include <algorithm>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

using absl::StrCat;
using absl::StrJoin;

HloSharding HloSharding::AssignDevice(int64_t device_id,
                                      absl::Span<const OpMetadata> metadata) {
  return HloSharding(device_id, metadata);
}

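// Tiles a rank-1 shape across `num_tiles` devices; the assignment is simply
// devices 0..num_tiles-1 in order. For example, Tile1D(f32[8], 4, {}) places
// the four tiles on devices {0, 1, 2, 3}.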
HloSharding HloSharding::Tile1D(const Shape& input_shape, int64_t num_tiles,
                                absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(1, input_shape.rank());
  CHECK_GT(num_tiles, 1);
  std::vector<int64> dimensions(1, num_tiles);
  Array<int64> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(assignment, /*replicate_on_last_tile_dim=*/false,
                     metadata);
}

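// Expands a per-group tile assignment into a full per-device assignment by
// appending one dimension that enumerates each group's replicas. For example,
// group_tile_assignment {{0}, {1}} (shape [2,1]) with replication_groups
// {{0, 1}, {2, 3}} yields the [2,1,2] assignment {{{0, 1}}, {{2, 3}}}, which
// is then normalized by the overload below.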
HloSharding HloSharding::PartialTile(
    const Array<int64>& group_tile_assignment,
    absl::Span<const absl::Span<const int64>> replication_groups,
    absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(group_tile_assignment.num_elements(), replication_groups.size());
  if (replication_groups.size() == 1) {
    return Replicate(metadata);
  }
  auto new_tile_dims = group_tile_assignment.dimensions();
  new_tile_dims.push_back(replication_groups[0].size());
  auto new_tile_assignment = Array<int64>(new_tile_dims);
  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
    std::vector<int64> group_index(indices.begin(), indices.end());
    group_index.pop_back();
    int64_t group = group_tile_assignment(group_index);
    *device = replication_groups[group][indices.back()];
  });
  return PartialTile(new_tile_assignment, metadata);
}

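// Canonicalizes a tile assignment whose last dimension holds the replicas of
// each tile. Degenerate cases collapse to Replicate() or to a plain tiled
// sharding; otherwise the devices within each replication group are reordered
// into ascending order so that equivalent shardings compare equal. For
// example, a [2,2] assignment {{1, 0}, {3, 2}} becomes {{0, 1}, {2, 3}} with
// replicate_on_last_tile_dim set.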
HloSharding HloSharding::PartialTile(
    const Array<int64>& tile_assignment_last_dim_replicate,
    absl::Span<const OpMetadata> metadata) {
  if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
      tile_assignment_last_dim_replicate.dimensions().back() ==
          tile_assignment_last_dim_replicate.num_elements()) {
    return Replicate(metadata);
  }
  if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
    auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
    new_tile_dims.pop_back();
    auto fully_tiled = tile_assignment_last_dim_replicate;
    fully_tiled.Reshape(new_tile_dims);
    return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
  std::vector<std::set<int64>> sorted_groups(
      tile_assignment_last_dim_replicate.num_elements() /
      tile_assignment_last_dim_replicate.dimensions().back());
  auto get_group_id = [&](absl::Span<const int64> indices) {
    int64_t group_id = 0;
    for (int64_t i = 0; i < indices.size() - 1; ++i) {
      group_id *= tile_assignment_last_dim_replicate.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment_last_dim_replicate.Each(
      [&](absl::Span<const int64> indices, const int64_t device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64> sorted_tile(tile_assignment_last_dim_replicate.dimensions());
  sorted_tile.Each([&](absl::Span<const int64> indices, int64* device) {
    const int64_t group_id = get_group_id(indices);
    auto begin = sorted_groups[group_id].begin();
    *device = *begin;
    sorted_groups[group_id].erase(begin);
  });
  return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true,
                     metadata);
}

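// Like PartialTile above, but the trailing `sharding_types.size()` dimensions
// of the tile assignment each describe a subgroup (e.g. REPLICATED or MANUAL)
// rather than a single replication dimension. Devices within each subgroup
// are likewise stored in ascending order so the representation is canonical.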
HloSharding HloSharding::Subgroup(
    const Array<int64>& tile_assignment,
    absl::Span<const OpSharding::Type> sharding_types,
    absl::Span<const OpMetadata> metadata) {
  if (sharding_types.empty()) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }

  CHECK_GT(tile_assignment.dimensions().size(), sharding_types.size());
  int64 num_group_dims = sharding_types.size();
  std::vector<int64> group_dimensions(
      tile_assignment.dimensions().end() - num_group_dims,
      tile_assignment.dimensions().end());
  // There is one group per combination of the leading (non-subgroup) indices,
  // so divide the element count by the subgroup size.
  int64 num_groups =
      tile_assignment.num_elements() / Product(group_dimensions);
  std::vector<std::set<int64>> sorted_groups(num_groups);
  auto get_group_id = [&](absl::Span<const int64> indices) {
    int64_t group_id = 0;
    for (int64_t i = 0; i < indices.size() - num_group_dims; ++i) {
      group_id *= tile_assignment.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment.Each(
      [&](absl::Span<const int64> indices, const int64_t device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64> sorted_tile(tile_assignment.dimensions());
  sorted_tile.Each([&](absl::Span<const int64> indices, int64* device) {
    const int64_t group_id = get_group_id(indices);
    auto begin = sorted_groups[group_id].begin();
    *device = *begin;
    sorted_groups[group_id].erase(begin);
  });

  return HloSharding(sorted_tile, sharding_types, metadata);
}

HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
  std::vector<HloSharding> flattened_list;
  flattened_list.reserve(sub_shardings.leaf_count());
  for (const auto& index_to_sharding : sub_shardings.leaves()) {
    flattened_list.push_back(index_to_sharding.second);
  }
  if (flattened_list.empty()) {
    // Empty tuple sharding ends up having no leaves, but we want to allow
    // empty tuple HLO instruction results to have sharding, so we fetch the
    // root ({}) sharding value from the ShapeTree.
    // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
    // init as value at its root.
    flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
  }
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Tuple(const Shape& tuple_shape,
                               absl::Span<const HloSharding> shardings) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  for (auto& sharding : shardings) {
    CHECK(!sharding.IsTuple())
        << sharding.ToString() << ShapeUtil::HumanString(tuple_shape);
  }
  std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
  CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
      << "Flat list has " << flattened_list.size() << ", required "
      << RequiredLeaves(tuple_shape);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();
  int64_t leaf_count = RequiredLeaves(tuple_shape);
  std::vector<HloSharding> flattened_list;
  flattened_list.resize(leaf_count, sharding);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding;
}

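// Renders the sharding in the HLO text syntax, e.g. "{replicated}",
// "{maximal device=3}", "{devices=[2,2]0,1,2,3}", or
// "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"; tuple shardings render as
// a brace-enclosed, comma-separated list of their elements.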
string HloSharding::ToString(bool include_metadata) const {
  if (IsTuple()) {
    CHECK(metadata_.empty());
    std::string result = "{";
    for (int i = 0; i < tuple_elements_.size(); ++i) {
      const HloSharding& element = tuple_elements_[i];
      if (i != 0) {
        absl::StrAppend(&result, ", ");
        if (i % 5 == 0) {
          absl::StrAppend(&result, "/*index=", i, "*/");
        }
      }
      absl::StrAppend(&result, element.ToString(include_metadata));
    }
    absl::StrAppend(&result, "}");
    return result;
  }

  std::string metadata;
  if (include_metadata) {
    if (metadata_.size() == 1) {
      metadata =
          StrCat(" metadata={", OpMetadataToString(metadata_.front()), "}");
    } else if (metadata_.size() > 1) {
      std::vector<std::string> metadata_strings;
      metadata_strings.reserve(metadata_.size());
      for (const auto& single_metadata : metadata_) {
        metadata_strings.push_back(
            StrCat("{", OpMetadataToString(single_metadata), "}"));
      }
      metadata = StrCat(" metadata={", StrJoin(metadata_strings, ", "), "}");
    }
  }

  std::string last_tile_dims;
  if (!sharding_types_.empty()) {
    auto op_sharding_type_to_string = [](OpSharding::Type type) {
      switch (type) {
        case OpSharding::MANUAL:
          return "manual";
        case OpSharding::MAXIMAL:
          return "maximal";
        case OpSharding::REPLICATED:
          return "replicated";
        default:
          return "error_type";
      }
    };
    std::vector<std::string> sharding_type_strings;
    sharding_type_strings.reserve(sharding_types_.size());
    for (const auto& single_sharding_type : sharding_types_) {
      sharding_type_strings.push_back(
          op_sharding_type_to_string(single_sharding_type));
    }
    last_tile_dims =
        StrCat(" last_tile_dims={", StrJoin(sharding_type_strings, ", "), "}");
  }

  if (replicated_) {
    return StrCat("{replicated", metadata, "}");
  }

  if (manual_) {
    return StrCat("{manual", metadata, "}");
  }
  if (maximal_) {
    return StrCat("{maximal device=",
                  static_cast<int64>(*tile_assignment_.begin()), metadata, "}");
  }
  return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
                StrJoin(tile_assignment_, ","),
                replicate_on_last_tile_dim_ ? " last_tile_dim_replicate" : "",
                last_tile_dims, metadata, "}");
}

bool HloSharding::UsesDevice(int64_t device) const {
  if (IsTuple()) {
    return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
      return s.UsesDevice(device);
    });
  }
  const auto& devices = tile_assignment_;
  return replicated_ || manual_ || absl::c_linear_search(devices, device);
}

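// Returns a map from device id to the number of tuple leaves (or 1 for a
// non-tuple sharding) assigned wholly to that device; shardings without a
// unique device contribute nothing. If `count` is non-null, it receives the
// number of elements examined.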
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
  int64_t element_count = 1;
  std::map<int64, int64> device_map;
  if (IsTuple()) {
    for (auto& tuple_element_sharding : tuple_elements()) {
      auto unique_device = tuple_element_sharding.UniqueDevice();
      if (unique_device) {
        device_map[*unique_device] += 1;
      }
    }
    element_count = tuple_elements().size();
  } else {
    auto unique_device = UniqueDevice();
    if (unique_device) {
      device_map[*unique_device] += 1;
    }
  }
  if (count != nullptr) {
    *count = element_count;
  }
  return device_map;
}

std::vector<int64> HloSharding::TileIndexForDevice(int64_t device) const {
  CHECK(!maximal_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  std::vector<int64> ret_index;
  tile_assignment_.Each([&](absl::Span<const int64> index, int64_t d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  if (replicate_on_last_tile_dim_) {
    ret_index.pop_back();
  }
  return ret_index;
}

int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
  CHECK(!replicated_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  if (replicate_on_last_tile_dim_ &&
      index.size() < tile_assignment().num_dimensions()) {
    std::vector<int64> first_replicated_index(index.begin(), index.end());
    first_replicated_index.push_back(0);
    return tile_assignment_(first_replicated_index);
  }
  return tile_assignment_(index);
}

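// The two helpers below compute the half-open interval [offset, limit) that a
// device's tile covers in each dimension, using a per-tile size of
// CeilOfRatio(dim, num_tiles) and clamping at the shape bound. For example, a
// dimension of size 10 split across 4 tiles has per-tile size 3, offsets
// {0, 3, 6, 9}, and limits {3, 6, 9, 10}.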
std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
                                                    int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions_size(), 0);
  }
  if (replicate_on_last_tile_dim_) {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions() - 1);
  } else {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
  }
  std::vector<int64> index = TileIndexForDevice(device);
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
  }
  return index;
}

std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
                                                   int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions().begin(),
                              shape.dimensions().end());
  }

  CHECK_EQ(shape.dimensions_size() + (ReplicateOnLastTileDim() ? 1 : 0),
           tile_assignment_.num_dimensions());
  std::vector<int64> index = TileIndexForDevice(device);
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
  }
  return index;
}

int64 HloSharding::RequiredLeaves(const Shape& shape) {
  // Empty tuples (with arbitrary nesting) have no leaf nodes as far as
  // ShapeUtil and ShapeTree are concerned, but they do have a single
  // tuple_elements_ entry since we want to allow empty tuple results to
  // have sharding.
  const int64_t leaf_count = ShapeUtil::GetLeafCount(shape);
  return (leaf_count == 0) ? 1 : leaf_count;
}

Status HloSharding::CheckLeafCount(const Shape& shape) const {
  int64_t shape_leaves = RequiredLeaves(shape);
  TF_RET_CHECK(shape_leaves == tuple_elements_.size())
      << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
      << " leaf nodes while this sharding has " << tuple_elements_.size();
  return Status::OK();
}

StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
    const Shape& shape) const {
  if (IsTuple()) {
    ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    auto it = tuple_elements_.begin();
    for (auto& index_to_sharding : result.leaves()) {
      index_to_sharding.second = *it++;
    }
    if (ShapeUtil::IsEmptyTuple(shape)) {
      // Empty tuples have no leaves, but we want to assign them a sharding
      // anyway, so we use the root element sharding.
      *result.mutable_element(ShapeIndex({})) = *it;
    }
    return std::move(result);
  } else {
    return ShapeTree<HloSharding>(shape, *this);
  }
}

StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
  if (IsTuple()) {
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    return *this;
  }
  return Tuple(ShapeTree<HloSharding>(shape, *this));
}

absl::optional<int64> HloSharding::UniqueDevice() const {
  if (IsTuple()) {
    if (tuple_elements_.empty()) {
      return absl::nullopt;
    }
    absl::optional<int64> unique_device;
    for (auto& tuple_sharding : tuple_elements_) {
      auto device = tuple_sharding.UniqueDevice();
      if (!device || (unique_device && *device != *unique_device)) {
        return absl::nullopt;
      }
      unique_device = device;
    }
    return unique_device;
  }
  if (!replicated_ && maximal_) {
    return static_cast<int64>(*tile_assignment_.begin());
  }
  return absl::nullopt;
}

int64 HloSharding::GetUniqueDevice() const {
  auto device = UniqueDevice();
  CHECK(device) << "Sharding does not have a unique device: " << *this;
  return *device;
}

Status HloSharding::ValidateTuple(const Shape& shape,
                                  int64_t num_devices) const {
  if (!shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Sharding is tuple-shaped but validation shape is not."));
  }
  TF_RETURN_IF_ERROR(CheckLeafCount(shape));

  // Now that we've validated the number of tuple elements, it's safe to
  // request a shape tree.
  ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
  for (const auto& index_to_sharding : shape_tree.leaves()) {
    Status status = index_to_sharding.second.ValidateNonTuple(
        ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
    if (!status.ok()) {
      tensorflow::errors::AppendToMessage(
          &status, StrCat("Note: While validating sharding tuple element ",
                          index_to_sharding.first.ToString(), " which is ",
                          index_to_sharding.second.ToString()));
      return status;
    }
  }
  return Status::OK();
}

Status HloSharding::Validate(const Shape& shape, int64_t num_devices) const {
  if (shape.IsToken()) {
    return Status::OK();
  }
  Status status = IsTuple() ? ValidateTuple(shape, num_devices)
                            : ValidateNonTuple(shape, num_devices);
  if (!status.ok()) {
    tensorflow::errors::AppendToMessage(
        &status, StrCat("Note: While validating sharding ", ToString(),
                        " against shape ", ShapeUtil::HumanString(shape)));
  }
  return status;
}

Status HloSharding::ValidateNonTuple(const Shape& shape,
                                     int64_t num_devices) const {
  if (shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation shape is a tuple but sharding is not."));
  }
  if (replicated_) {
    return Status::OK();
  }

  // All tile assignment entries must be unique and less than the number of
  // available devices.
  Status status = Status::OK();
  absl::flat_hash_set<int64> seen_cores;
  tile_assignment_.Each([&](absl::Span<const int64> indices, int32_t core) {
    // Don't overwrite a bad status, so we report the first error.
    if (status.ok()) {
      if (core >= num_devices) {
        status = tensorflow::errors::InvalidArgument(
            StrCat("core ", core, " >= ", num_devices, " in tile assignment"));
      } else if (seen_cores.contains(core)) {
        status = tensorflow::errors::InvalidArgument(
            StrCat("core ", core, " is not unique in tile assignment"));
      }
      seen_cores.insert(core);
    }
  });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal() || IsManual()) {
    return Status::OK();
  }

  // The tile assignment tensor must have the same rank as the input, or input
  // rank + 1 for replicate_on_last_tile_dim_.
  if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) !=
      tile_assignment_.num_dimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Number of tile assignment dimensions is different from the input "
        "rank. sharding=",
        ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
  }

  // The correct constructor has to be used to create tile maximal shardings.
  if (tile_assignment_.num_elements() == 1) {
    return tensorflow::errors::InvalidArgument(
        "Tile assignment only contains a single device. If a replicated "
        "sharding was intended, use HloSharding::Replicate(). If a device "
        "placement was intended, use HloSharding::AssignDevice()");
  }
  return Status::OK();
}

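// Reconstructs an HloSharding from its OpSharding proto: tuple shardings
// recurse over tuple_shardings(), REPLICATED/MANUAL map to the corresponding
// factories, a single-device assignment becomes a maximal sharding, and
// anything else is rebuilt as a (possibly partially replicated or subgrouped)
// tile assignment after validating the dimensions.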
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  std::vector<OpMetadata> metadata(proto.metadata().begin(),
                                   proto.metadata().end());
  std::vector<int> sharding_types_int(proto.last_tile_dims().begin(),
                                      proto.last_tile_dims().end());
  std::vector<OpSharding::Type> sharding_types;
  absl::c_transform(
      sharding_types_int, std::back_inserter(sharding_types),
      [](const int type) { return static_cast<OpSharding::Type>(type); });
  if (proto.type() == OpSharding::TUPLE) {
    TF_RET_CHECK(metadata.empty())
        << "Tuple sharding is expected to have no metadata.";
    std::vector<HloSharding> tuple_shardings;
    tuple_shardings.reserve(proto.tuple_shardings().size());
    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
      TF_ASSIGN_OR_RETURN(HloSharding sharding,
                          HloSharding::FromProto(tuple_sharding_proto));
      tuple_shardings.push_back(sharding);
    }
    return HloSharding(tuple_shardings);
  } else if (proto.type() == OpSharding::REPLICATED) {
    return Replicate(metadata);
  } else if (proto.type() == OpSharding::MANUAL) {
    return Manual(metadata);
  } else if (proto.tile_assignment_devices().size() == 1) {
    return HloSharding(proto.tile_assignment_devices(0), metadata);
  }

  TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
      << "Maximal sharding is expected to have a single device assignment, "
      << "but " << proto.tile_assignment_devices().size() << " were provided.";

  TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
  TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());

  // The product of the tile assignment tensor dimensions must equal
  // tile_assignment_devices.size(); check it without overflowing.
  int64_t product_of_dimensions = 1;
  for (auto dimension : proto.tile_assignment_dimensions()) {
    TF_RET_CHECK(dimension > 0);
    product_of_dimensions =
        MultiplyWithoutOverflow(product_of_dimensions, dimension);
    TF_RET_CHECK(product_of_dimensions > 0);
  }
  TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());

  // Some versions of gcc cannot infer the TileAssignment constructor from a
  // braced initializer-list, so create the dimensions vector manually.
  Array<int64> tile_assignment(
      std::vector<int64>(proto.tile_assignment_dimensions().begin(),
                         proto.tile_assignment_dimensions().end()));
  std::copy(proto.tile_assignment_devices().begin(),
            proto.tile_assignment_devices().end(), tile_assignment.begin());
  if (!sharding_types.empty()) {
    TF_RET_CHECK(!proto.replicate_on_last_tile_dim());
    return HloSharding(tile_assignment, sharding_types, metadata);
  }
  return proto.replicate_on_last_tile_dim()
             ? PartialTile(tile_assignment, metadata)
             : HloSharding(tile_assignment,
                           /*replicate_on_last_tile_dim=*/false, metadata);
}

OpSharding HloSharding::ToProto() const {
  OpSharding result;

  if (IsTuple()) {
    CHECK(metadata_.empty());
    for (const HloSharding& element : tuple_elements_) {
      *result.add_tuple_shardings() = element.ToProto();
    }
    result.set_type(OpSharding::TUPLE);
    return result;
  }

  result.mutable_metadata()->Reserve(metadata_.size());
  for (const auto& metadata : metadata_) {
    *result.add_metadata() = metadata;
  }

  for (int64_t dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::REPLICATED);
    result.clear_tile_assignment_dimensions();
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::MAXIMAL);
  } else if (IsManual()) {
    result.set_type(OpSharding::MANUAL);
    result.clear_tile_assignment_dimensions();
  } else {
    result.set_type(OpSharding::OTHER);
    result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
  }
  return result;
}

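// Returns the shape of one tile, rounding up: each dimension becomes
// CeilOfRatio(dim, tiles_in_that_dim). For example, f32[10,7] tiled with
// devices=[2,2] has tile shape f32[5,4]; edge tiles may be smaller, which the
// device-specific overload below accounts for.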
Shape HloSharding::TileShape(const Shape& shape) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }
  Shape result_shape = shape;
  for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
    result_shape.set_dimensions(
        i, CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i)));
  }
  return result_shape;
}

Shape HloSharding::TileShape(const Shape& shape, int64_t device) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }

  std::vector<int64> index = TileIndexForDevice(device);
  Shape result_shape = shape;
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    int64_t offset = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
    int64_t limit = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
    result_shape.set_dimensions(i, limit - offset);
  }
  return result_shape;
}

int64 HloSharding::NumTiles() const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  if (ReplicateOnLastTileDim()) {
    return tile_assignment().num_elements() /
           tile_assignment().dimensions().back();
  }
  return tile_assignment().num_elements();
}

int64 HloSharding::NumTiles(absl::Span<const int64> dims) const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  CHECK(!ReplicateOnLastTileDim() ||
        !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));
  int64_t num_tiles = 1;
  for (auto d : dims) {
    CHECK(d < tile_assignment().num_dimensions());
    num_tiles *= tile_assignment().dim(d);
  }
  return num_tiles;
}

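// Extracts the sharding of the sub-shape at `index` by walking the tuple
// structure and accumulating the flattened leaf position. For example, for a
// tuple shape (a, (b, c)) whose leaf shardings are stored flat as [a, b, c],
// index {1, 0} selects the element at flattened position 1 (the sharding of
// b), while index {1} rebuilds a tuple sharding from positions 1 and 2.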
HloSharding HloSharding::GetSubSharding(const Shape& shape,
                                        const ShapeIndex& index) const {
  CHECK(IsTuple());
  int64_t sharding_index = 0;
  const Shape* sub_shape = &shape;
  for (int64_t idx : index) {
    for (int64_t i = 0; i < idx; ++i) {
      sharding_index +=
          ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
    }
    sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
  }
  if (sub_shape->IsTuple()) {
    auto begin_it = tuple_elements_.begin() + sharding_index;
    std::vector<HloSharding> sub_shardings(
        begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
    return HloSharding::Tuple(*sub_shape, sub_shardings);
  } else {
    return tuple_elements_[sharding_index];
  }
}

absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
  if (!IsTuple()) {
    return *this;
  }
  if (tuple_elements_.empty()) {
    return absl::nullopt;
  }
  for (int64_t i = 1; i < tuple_elements_.size(); ++i) {
    if (tuple_elements_[0] != tuple_elements_[i]) {
      return absl::nullopt;
    }
  }
  return tuple_elements_.front();
}

HloSharding HloSharding::WithMetadata(absl::Span<const OpMetadata> metadata,
                                      bool overwrite) const {
  auto assign_metadata = [&](HloSharding& sharding) {
    if (sharding.metadata_.empty() || overwrite) {
      sharding.metadata_.assign(metadata.begin(), metadata.end());
    }
  };

  HloSharding sharding = *this;
  if (sharding.IsTuple()) {
    for (HloSharding& sub_sharding : sharding.tuple_elements()) {
      assign_metadata(sub_sharding);
    }
  } else {
    assign_metadata(sharding);
  }
  return sharding;
}

HloSharding HloSharding::WithoutMetadata() const {
  HloSharding sharding = *this;
  sharding.metadata_.clear();
  for (HloSharding& sub_sharding : sharding.tuple_elements()) {
    sub_sharding.metadata_.clear();
  }
  return sharding;
}

size_t HloSharding::Hash() const {
  if (tuple_) {
    size_t h = 0;
    for (const auto& element : tuple_elements_) {
      h = tensorflow::Hash64Combine(h, element.Hash());
    }
    return h;
  }
  if (replicated_) {
    return 0;
  }
  if (manual_) {
    return 1;
  }
  size_t h = 0;
  for (uint32 v : tile_assignment_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
  }
  if (replicate_on_last_tile_dim_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(1));
  }
  return h;
}

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
  out << sharding.ToString();
  return out;
}

}  // namespace xla