1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/compiler/xla/overflow_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23
24 namespace xla {
25
26 using absl::StrCat;
27 using absl::StrJoin;
28
AssignDevice(int64 device_id)29 HloSharding HloSharding::AssignDevice(int64 device_id) {
30 return HloSharding(device_id);
31 }
32
Tile1D(const Shape & input_shape,int64 num_tiles)33 HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
34 CHECK_EQ(1, input_shape.rank());
35 CHECK_GT(num_tiles, 1);
36 std::vector<int64> dimensions(1, num_tiles);
37 Array<int64> assignment(dimensions);
38 std::iota(assignment.begin(), assignment.end(), 0);
39 return HloSharding(assignment);
40 }
41
Tuple(const ShapeTree<HloSharding> & sub_shardings)42 HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
43 std::vector<HloSharding> flattened_list;
44 flattened_list.reserve(sub_shardings.leaf_count());
45 for (const auto& index_to_sharding : sub_shardings.leaves()) {
46 flattened_list.push_back(index_to_sharding.second);
47 }
48 if (flattened_list.empty()) {
49 // Empty tuple sharding ends up having no leaves, but we want to allow
50 // empty tuple HLO instruction results to have sharding, so we fetch the
51 // root ({}) sharding value from the ShapeTree.
52 // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
53 // init as value at its root.
54 flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
55 }
56 return HloSharding(flattened_list);
57 }
58
Tuple(const Shape & tuple_shape,absl::Span<const HloSharding> shardings)59 HloSharding HloSharding::Tuple(const Shape& tuple_shape,
60 absl::Span<const HloSharding> shardings) {
61 CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
62 for (auto& sharding : shardings) {
63 CHECK(!sharding.IsTuple()) << sharding.ToString();
64 }
65 std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
66 CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
67 << "Flat list has " << flattened_list.size() << ", required "
68 << RequiredLeaves(tuple_shape);
69 return HloSharding(flattened_list);
70 }
71
// Creates a tuple sharding for `tuple_shape` in which every required leaf is
// given a copy of the same non-tuple `sharding`.
//
// CHECK-fails if `tuple_shape` is not a tuple or `sharding` is a tuple
// sharding.
HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();

  // One copy of `sharding` per leaf required by the shape.
  const std::vector<HloSharding> per_leaf(RequiredLeaves(tuple_shape),
                                          sharding);
  return HloSharding(per_leaf);
}
81
// Returns `sharding` unchanged for a non-tuple `shape`; for a tuple `shape`
// returns SingleTuple(shape, sharding), i.e. `sharding` applied to every leaf.
HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  if (!shape.IsTuple()) {
    return sharding;
  }
  return SingleTuple(shape, sharding);
}
86
ToString() const87 string HloSharding::ToString() const {
88 if (IsTuple()) {
89 std::vector<string> parts;
90 parts.reserve(tuple_elements_.size());
91 for (const HloSharding& element : tuple_elements_) {
92 parts.push_back(element.ToString());
93 }
94 return StrCat("{", absl::StrJoin(parts, ", "), "}");
95 }
96
97 if (replicated_) {
98 return "{replicated}";
99 }
100 if (maximal_) {
101 return StrCat(
102 "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
103 }
104 return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
105 StrJoin(tile_assignment_, ","), "}");
106 }
107
UsesDevice(int64 device) const108 bool HloSharding::UsesDevice(int64 device) const {
109 if (IsTuple()) {
110 return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
111 return s.UsesDevice(device);
112 });
113 }
114 const auto& devices = tile_assignment_;
115 return replicated_ || absl::c_linear_search(devices, device);
116 }
117
UsedDevices(int64 * count) const118 std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
119 int64 element_count = 1;
120 std::map<int64, int64> device_map;
121 if (IsTuple()) {
122 for (auto& tuple_element_sharding : tuple_elements()) {
123 auto unique_device = tuple_element_sharding.UniqueDevice();
124 if (unique_device) {
125 device_map[*unique_device] += 1;
126 }
127 }
128 element_count = tuple_elements().size();
129 } else {
130 auto unique_device = UniqueDevice();
131 if (unique_device) {
132 device_map[*unique_device] += 1;
133 }
134 }
135 if (count != nullptr) {
136 *count = element_count;
137 }
138 return device_map;
139 }
140
TileIndexForDevice(int64 device) const141 std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
142 CHECK(!maximal_);
143 CHECK(!IsTuple());
144 std::vector<int64> ret_index;
145 tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
146 if (d == device) {
147 ret_index = {index.begin(), index.end()};
148 }
149 });
150 CHECK(!ret_index.empty());
151 return ret_index;
152 }
153
DeviceForTileIndex(absl::Span<const int64> index) const154 int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
155 CHECK(!replicated_);
156 CHECK(!IsTuple());
157 if (maximal_) {
158 return *tile_assignment_.begin();
159 }
160 return tile_assignment_(index);
161 }
162
TileOffsetForDevice(const Shape & shape,int64 device) const163 std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
164 int64 device) const {
165 CHECK(!IsTuple());
166
167 if (maximal_) {
168 return std::vector<int64>(shape.dimensions_size(), 0);
169 }
170
171 CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
172 std::vector<int64> index = TileIndexForDevice(device);
173 for (int64 i = 0; i < index.size(); ++i) {
174 const int64 shape_dim = shape.dimensions(i);
175 index[i] = std::min(
176 index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
177 }
178 return index;
179 }
180
TileLimitForDevice(const Shape & shape,int64 device) const181 std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
182 int64 device) const {
183 CHECK(!IsTuple());
184
185 if (maximal_) {
186 return std::vector<int64>(shape.dimensions().begin(),
187 shape.dimensions().end());
188 }
189
190 CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
191 std::vector<int64> index = TileIndexForDevice(device);
192 for (int64 i = 0; i < index.size(); ++i) {
193 const int64 shape_dim = shape.dimensions(i);
194 index[i] = std::min(
195 (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
196 shape_dim);
197 }
198 return index;
199 }
200
RequiredLeaves(const Shape & shape)201 int64 HloSharding::RequiredLeaves(const Shape& shape) {
202 // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are
203 // concerned, but they do have a single tuple_elements_ entry since we want
204 // to allow empty tuple results to have sharding.
205 return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape);
206 }
207
CheckLeafCount(const Shape & shape) const208 Status HloSharding::CheckLeafCount(const Shape& shape) const {
209 int64 shape_leaves = RequiredLeaves(shape);
210 TF_RET_CHECK(shape_leaves == tuple_elements_.size())
211 << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
212 << " leaf nodes while this sharding has " << tuple_elements_.size();
213 return Status::OK();
214 }
215
AsShapeTree(const Shape & shape) const216 StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
217 const Shape& shape) const {
218 if (IsTuple()) {
219 ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
220 TF_RETURN_IF_ERROR(CheckLeafCount(shape));
221 auto it = tuple_elements_.begin();
222 for (auto& index_to_sharding : result.leaves()) {
223 index_to_sharding.second = *it++;
224 }
225 if (ShapeUtil::IsEmptyTuple(shape)) {
226 // Empty tuples have no leaves, but we want to assign them a sharding
227 // anyway, so we use the root element sharding.
228 *result.mutable_element(ShapeIndex({})) = *it;
229 }
230 return std::move(result);
231 } else {
232 return ShapeTree<HloSharding>(shape, *this);
233 }
234 }
235
GetTupleSharding(const Shape & shape) const236 StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
237 if (IsTuple()) {
238 TF_RETURN_IF_ERROR(CheckLeafCount(shape));
239 return *this;
240 }
241 return Tuple(ShapeTree<HloSharding>(shape, *this));
242 }
243
UniqueDevice() const244 absl::optional<int64> HloSharding::UniqueDevice() const {
245 if (IsTuple()) {
246 if (tuple_elements_.empty()) {
247 return absl::nullopt;
248 }
249 absl::optional<int64> unique_device;
250 for (auto& tuple_sharding : tuple_elements_) {
251 auto device = tuple_sharding.UniqueDevice();
252 if (!device || (unique_device && *device != *unique_device)) {
253 return absl::nullopt;
254 }
255 unique_device = device;
256 }
257 return unique_device;
258 }
259 if (!replicated_ && maximal_) {
260 return static_cast<int64>(*tile_assignment_.begin());
261 }
262 return absl::nullopt;
263 }
264
GetUniqueDevice() const265 int64 HloSharding::GetUniqueDevice() const {
266 auto device = UniqueDevice();
267 CHECK(device) << "Sharding does not have a unique device: " << *this;
268 return *device;
269 }
270
ValidateTuple(const Shape & shape,int64 num_devices) const271 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
272 if (!shape.IsTuple()) {
273 return tensorflow::errors::InvalidArgument(
274 StrCat("Sharding is tuple-shaped but validation shape is not."));
275 }
276 TF_RETURN_IF_ERROR(CheckLeafCount(shape));
277
278 // Now we've validated the number of tuple elements, it's safe to request a
279 // shape tree.
280 ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
281 for (const auto& index_to_sharding : shape_tree.leaves()) {
282 Status status = index_to_sharding.second.ValidateNonTuple(
283 ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
284 if (!status.ok()) {
285 tensorflow::errors::AppendToMessage(
286 &status, StrCat("Note: While validating sharding tuple element ",
287 index_to_sharding.first.ToString(), " which is ",
288 index_to_sharding.second.ToString()));
289 return status;
290 }
291 }
292 return Status::OK();
293 }
294
Validate(const Shape & shape,int64 num_devices) const295 Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
296 Status status = IsTuple() ? ValidateTuple(shape, num_devices)
297 : ValidateNonTuple(shape, num_devices);
298 if (!status.ok()) {
299 tensorflow::errors::AppendToMessage(
300 &status, StrCat("Note: While validating sharding ", ToString(),
301 " against shape ", ShapeUtil::HumanString(shape)));
302 }
303 return status;
304 }
305
ValidateNonTuple(const Shape & shape,int64 num_devices) const306 Status HloSharding::ValidateNonTuple(const Shape& shape,
307 int64 num_devices) const {
308 if (shape.IsTuple()) {
309 return tensorflow::errors::InvalidArgument(
310 StrCat("Validation shape is a tuple but sharding is not."));
311 }
312 if (replicated_) {
313 return Status::OK();
314 }
315
316 // All tile assignments must be less than the number of available cores and
317 // unique.
318 Status status = Status::OK();
319 absl::flat_hash_set<int64> seen_cores;
320 tile_assignment_.Each(
321 [&](absl::Span<const int64> indices, int32 core) {
322 // Don't overwrite a bad status, so we report the first error.
323 if (status.ok()) {
324 if (core >= num_devices) {
325 status = tensorflow::errors::InvalidArgument(StrCat(
326 "core ", core, " > ", num_devices, " in tile assignment"));
327 } else if (seen_cores.contains(core)) {
328 status = tensorflow::errors::InvalidArgument(
329 StrCat("core ", core, " is not unique in tile assignment"));
330 }
331 seen_cores.insert(core);
332 }
333 });
334 if (!status.ok()) {
335 return status;
336 }
337
338 if (IsTileMaximal()) {
339 return Status::OK();
340 }
341
342 // The tile assignment tensor must have the same rank as the input.
343 if (shape.rank() != tile_assignment_.num_dimensions()) {
344 return tensorflow::errors::InvalidArgument(
345 "Number of tile assignment dimensions is different to the input rank. "
346 "sharding=",
347 ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
348 }
349
350 // The correct constructor has to be used to create tile maximal shardings.
351 if (tile_assignment_.num_elements() == 1) {
352 return tensorflow::errors::InvalidArgument(
353 "Tile assignment only contains a single device. If a replicated "
354 "sharding was intended, use HloSharding::Replicated(). If a device "
355 "placement was intended, use HloSharding::AssignDevice()");
356 }
357 return Status::OK();
358 }
359
FromProto(const OpSharding & proto)360 /*static*/ StatusOr<HloSharding> HloSharding::FromProto(
361 const OpSharding& proto) {
362 if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) {
363 std::vector<HloSharding> tuple_shardings;
364 tuple_shardings.reserve(proto.tuple_shardings().size());
365 for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
366 TF_ASSIGN_OR_RETURN(HloSharding sharding,
367 HloSharding::FromProto(tuple_sharding_proto));
368 tuple_shardings.push_back(sharding);
369 }
370 return HloSharding(tuple_shardings);
371 } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
372 return Replicate();
373 } else if (proto.tile_assignment_devices().size() == 1) {
374 return HloSharding(proto.tile_assignment_devices(0));
375 }
376
377 TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL)
378 << "Maximal sharding is expected to have single device assignment, but "
379 << proto.tile_assignment_devices().size() << " has provided.";
380
381 TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
382 TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
383
384 // RE: the product of tile assignment tensor dimensions must be
385 // equal to tile_assignment_devices.size().
386 int64 product_of_dimensions = 1;
387 for (auto dimension : proto.tile_assignment_dimensions()) {
388 TF_RET_CHECK(dimension > 0);
389 product_of_dimensions =
390 MultiplyWithoutOverflow(product_of_dimensions, dimension);
391 TF_RET_CHECK(product_of_dimensions > 0);
392 }
393 TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
394
395 // Some versions of gcc cannot infer the TileAssignment constructor from a
396 // braced initializer-list, so create one manually.
397 std::vector<int64> devices(proto.tile_assignment_devices().begin(),
398 proto.tile_assignment_devices().end());
399 Array<int64> tile_assignment(
400 std::vector<int64>(proto.tile_assignment_dimensions().begin(),
401 proto.tile_assignment_dimensions().end()));
402 std::copy(proto.tile_assignment_devices().begin(),
403 proto.tile_assignment_devices().end(), tile_assignment.begin());
404 return HloSharding(tile_assignment);
405 }
406
ToProto() const407 OpSharding HloSharding::ToProto() const {
408 OpSharding result;
409
410 if (IsTuple()) {
411 for (const HloSharding& element : tuple_elements_) {
412 *result.add_tuple_shardings() = element.ToProto();
413 }
414 result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
415 return result;
416 }
417
418 for (int64 dim : tile_assignment_.dimensions()) {
419 result.add_tile_assignment_dimensions(dim);
420 }
421 for (auto device : tile_assignment_) {
422 result.add_tile_assignment_devices(device);
423 }
424 if (IsReplicated()) {
425 result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
426 } else if (IsTileMaximal()) {
427 result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
428 } else {
429 result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
430 }
431 return result;
432 }
433
// Returns the shape of a single tile of `shape` under this sharding.
//
// A tile-maximal sharding leaves the shape unchanged. Otherwise each
// dimension is divided by the number of tiles along it, rounding up.
Shape HloSharding::TileShape(const Shape& shape) const {
  Shape tiled = shape;
  if (!IsTileMaximal()) {
    const int64 rank = shape.dimensions_size();
    for (int64 dim = 0; dim < rank; ++dim) {
      const int64 full_extent = shape.dimensions(dim);
      const int64 num_tiles = tile_assignment_.dim(dim);
      tiled.set_dimensions(dim, CeilOfRatio<int64>(full_extent, num_tiles));
    }
  }
  return tiled;
}
445
GetSubSharding(const Shape & shape,const ShapeIndex & index) const446 HloSharding HloSharding::GetSubSharding(const Shape& shape,
447 const ShapeIndex& index) const {
448 CHECK(IsTuple());
449 int64 sharding_index = 0;
450 const Shape* sub_shape = &shape;
451 for (int64 idx : index) {
452 for (int64 i = 0; i < idx; ++i) {
453 sharding_index +=
454 ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
455 }
456 sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
457 }
458 if (sub_shape->IsTuple()) {
459 auto begin_it = tuple_elements_.begin() + sharding_index;
460 std::vector<HloSharding> sub_shardings(
461 begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
462 return HloSharding::Tuple(*sub_shape, sub_shardings);
463 } else {
464 return tuple_elements_[sharding_index];
465 }
466 }
467
ExtractSingleSharding() const468 absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
469 if (!IsTuple()) {
470 return *this;
471 }
472 if (tuple_elements_.empty()) {
473 return absl::nullopt;
474 }
475 for (int64 i = 1; i < tuple_elements_.size(); ++i) {
476 if (tuple_elements_[0] != tuple_elements_[i]) {
477 return absl::nullopt;
478 }
479 }
480 return tuple_elements_.front();
481 }
482
Hash() const483 size_t HloSharding::Hash() const {
484 if (tuple_) {
485 size_t h = 0;
486 for (const auto& element : tuple_elements_) {
487 h = tensorflow::Hash64Combine(h, element.Hash());
488 }
489 return h;
490 }
491 if (replicated_) {
492 return 0;
493 }
494 size_t h = 0;
495 for (uint32 v : tile_assignment_) {
496 h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
497 }
498 return h;
499 }
500
operator <<(std::ostream & out,const HloSharding & sharding)501 std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
502 out << sharding.ToString();
503 return out;
504 }
505
506 } // namespace xla
507