/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/layout_util.h"

namespace mlir {

// Rewrites the layout of xla_shape if there is tiled sharding.
xla::Status RewriteLayoutWithShardedShape(
    const std::optional<xla::HloSharding>& sharding, bool use_fast_memory,
    const LayoutPreferenceFn& layout_preference_fn,
    const ShapeRepresentationFn& shape_representation_fn,
    xla::Shape* xla_shape) {
  if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
    // After sharding, the per-core shape might have a different layout. For
    // example, before sharding, a shape [128, 128] will be assigned the
    // default minor-to-major {1, 0}. But after we shard this shape to
    // [128, 64] * 2, the sharded shapes will have minor-to-major {0, 1}.
    //
    // As a result, for sharded shapes, we set their layout to the per-core
    // shape's layout.
    //
    // TODO(endlessroad): for variable input & update, we might have
    // different layouts which will prevent input/output aliasing and
    // increase memory usage. Investigate such cases.
    int64_t device = *sharding->tile_assignment().begin();
    std::vector<int64_t> offset =
        sharding->TileOffsetForDevice(*xla_shape, device);
    std::vector<int64_t> limit =
        sharding->TileLimitForDevice(*xla_shape, device);
    // The per-device extent of each dimension is the tile's limit minus its
    // offset.
    std::vector<int64_t> dimensions(xla_shape->rank());
    for (int64_t i = 0; i < xla_shape->rank(); ++i) {
      dimensions[i] = limit[i] - offset[i];
    }
    xla::Shape per_device_xla_shape =
        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
    TF_ASSIGN_OR_RETURN(auto layout_preference,
                        layout_preference_fn
                            ? layout_preference_fn(per_device_xla_shape)
                            : XlaLayoutPreference::kNoPreference);
    TF_ASSIGN_OR_RETURN(
        per_device_xla_shape,
        shape_representation_fn
            ? shape_representation_fn(per_device_xla_shape, use_fast_memory,
                                      layout_preference)
            : per_device_xla_shape);
    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
  }
  return xla::Status::OK();
}
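
// Example (a minimal usage sketch, not from this file; it assumes a tiled
// xla::HloSharding named `sharding` is in scope, and passes null callbacks to
// exercise the fallbacks above):
//
//   xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 128});
//   TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
//       sharding, /*use_fast_memory=*/false,
//       /*layout_preference_fn=*/nullptr,
//       /*shape_representation_fn=*/nullptr, &shape));
//   // `shape` now carries the per-device layout, e.g. minor-to-major {0, 1}
//   // when [128, 128] is split into two [128, 64] tiles.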

// If there is a shape_representation_fn or sharding for an output, this
// function uses a reshape to fix the layout.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    const LayoutPreferenceFn& layout_preference_fn,
    const ShapeRepresentationFn& shape_representation_fn,
    std::optional<xla::OpSharding> sharding, bool fast_mem) {
  if (original_shape.IsTuple()) {
    // Recurse into each tuple element, pairing it with its per-element
    // sharding when one is present.
    std::vector<xla::XlaOp> elements;
    for (int i = 0; i < original_shape.tuple_shapes_size(); ++i) {
      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
      TF_ASSIGN_OR_RETURN(
          auto element,
          ReshapeWithCorrectRepresentationAndSharding(
              builder, xla::GetTupleElement(original, i),
              original_shape.tuple_shapes(i), layout_preference_fn,
              shape_representation_fn, subsharding, fast_mem));
      elements.push_back(element);
    }
    return xla::Tuple(builder, elements);
  }
  if (!original_shape.IsArray()) return original;
  TF_ASSIGN_OR_RETURN(auto layout_preference,
                      layout_preference_fn
                          ? layout_preference_fn(original_shape)
                          : XlaLayoutPreference::kNoPreference);
  TF_ASSIGN_OR_RETURN(
      auto to_shape,
      shape_representation_fn
          ? shape_representation_fn(original_shape, fast_mem, layout_preference)
          : original_shape);
  if (sharding) {
    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
                        xla::HloSharding::FromProto(*sharding));

    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
        hlo_sharding, fast_mem, layout_preference_fn, shape_representation_fn,
        &to_shape));
  }
  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
    // Propagate dynamic dimensions so the reshape preserves dynamism.
    for (int64_t i = 0; i < original_shape.rank(); ++i) {
      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
    }
  }
  return xla::Reshape(to_shape, original);
}
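
// Example (a minimal usage sketch, not from this file; it assumes an
// xla::XlaBuilder `b`, an op `arg` of shape `shape`, and a
// std::optional<xla::OpSharding> `sharding` are in scope):
//
//   TF_ASSIGN_OR_RETURN(
//       xla::XlaOp reshaped,
//       ReshapeWithCorrectRepresentationAndSharding(
//           &b, arg, shape, /*layout_preference_fn=*/nullptr,
//           /*shape_representation_fn=*/nullptr, sharding,
//           /*fast_mem=*/false));
//   // `reshaped` has the representation (and, under tiled sharding, the
//   // per-device layout) chosen above; tuples are handled element-wise.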

}  // namespace mlir