1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/tf2xla/layout_util.h"

#include <optional>
#include <vector>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
21 
22 namespace tensorflow {
23 
ShapeDeterminationFns()24 XlaShapeLayoutHelpers::ShapeDeterminationFns::ShapeDeterminationFns() {
25   layout_preference_fn = UseNoPreferenceLayoutFn();
26   shape_representation_fn = IdentityShapeRepresentationFn();
27 }
28 
UseNoPreferenceLayoutFn()29 XlaShapeLayoutHelpers::LayoutPreferenceFn UseNoPreferenceLayoutFn() {
30   return [](const TensorShape& shape, DataType dtype,
31             std::optional<XlaArgument::Kind>) -> XlaLayoutPreference {
32     return XlaLayoutPreference::kNoPreference;
33   };
34 }
35 
36 // Rewrites the layout of xla_shape if there is tiled sharding.
RewriteLayoutWithShardedShape(const std::optional<xla::HloSharding> & sharding,bool use_fast_memory,XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,xla::Shape * xla_shape)37 Status RewriteLayoutWithShardedShape(
38     const std::optional<xla::HloSharding>& sharding, bool use_fast_memory,
39     XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
40     xla::Shape* xla_shape) {
41   if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
42     // After sharding, per core shape might have different layout. For example,
43     // before sharding, a shape [128, 128] will be assigned default
44     // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
45     // the sharded shapes will have minor-to-major {0, 1}.
46     //
47     // As a result, for sharded shapes, we set their layout to per core shape's
48     // layout.
49     //
50     // TODO(endlessroad): for variable input & update, we might have
51     // different layouts which will prevent input output aliasing and
52     // increase memory usage. Investigate such cases.
53     int64_t device = *sharding->tile_assignment().begin();
54     std::vector<int64_t> offset =
55         sharding->TileOffsetForDevice(*xla_shape, device);
56     std::vector<int64_t> limit =
57         sharding->TileLimitForDevice(*xla_shape, device);
58     std::vector<int64_t> dimensions(xla_shape->rank());
59     for (int64_t i = 0; i < xla_shape->rank(); ++i) {
60       dimensions[i] = limit[i] - offset[i];
61     }
62     xla::Shape per_device_xla_shape =
63         xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
64     TensorShape per_device_tensor_shape;
65     TF_RETURN_IF_ERROR(
66         XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
67     TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
68                                             xla_shape->element_type()));
69     auto layout_preference = shape_determination_fns.layout_preference_fn(
70         per_device_tensor_shape, dtype, std::nullopt);
71     TF_ASSIGN_OR_RETURN(per_device_xla_shape,
72                         shape_determination_fns.shape_representation_fn(
73                             per_device_tensor_shape, dtype, use_fast_memory,
74                             layout_preference));
75     *xla_shape->mutable_layout() = per_device_xla_shape.layout();
76   }
77   return OkStatus();
78 }
79 
80 // There is a shape_representation_fn or sharding for an output, this function
81 // uses a reshape to fix the layout.
ReshapeWithCorrectRepresentationAndSharding(xla::XlaBuilder * builder,xla::XlaOp original,xla::Shape original_shape,XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,std::optional<xla::OpSharding> sharding,bool fast_mem)82 StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
83     xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
84     XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
85     std::optional<xla::OpSharding> sharding, bool fast_mem) {
86   if (original_shape.IsTuple()) {
87     std::vector<xla::XlaOp> elements;
88     for (int i = 0; i < original_shape.tuple_shapes_size(); ++i) {
89       auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
90       TF_ASSIGN_OR_RETURN(auto element,
91                           ReshapeWithCorrectRepresentationAndSharding(
92                               builder, xla::GetTupleElement(original, i),
93                               original_shape.tuple_shapes(i),
94                               shape_determination_fns, subsharding, fast_mem));
95       elements.push_back(element);
96     }
97     return xla::Tuple(builder, elements);
98   }
99   if (!original_shape.IsArray()) return original;
100   TensorShape shape;
101   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
102   TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
103                                           original_shape.element_type()));
104   auto layout_preference =
105       shape_determination_fns.layout_preference_fn(shape, dtype, std::nullopt);
106   TF_ASSIGN_OR_RETURN(auto to_shape,
107                       shape_determination_fns.shape_representation_fn(
108                           shape, dtype, fast_mem, layout_preference));
109   if (sharding) {
110     TF_ASSIGN_OR_RETURN(auto hlo_sharding,
111                         xla::HloSharding::FromProto(*sharding));
112 
113     TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
114         hlo_sharding, fast_mem, shape_determination_fns, &to_shape));
115   }
116   if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
117     for (int64_t i = 0; i < original_shape.rank(); ++i) {
118       to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
119     }
120   }
121   return xla::Reshape(to_shape, original);
122 }
123 
124 }  // namespace tensorflow
125