1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/shape_util.h"
17
18 #include <numeric>
19
20 #include "tensorflow/compiler/tf2xla/type_util.h"
21 #include "tensorflow/compiler/xla/layout_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/core/lib/core/status.h"
24
25 namespace tensorflow {
26 namespace {
27
PopulateInfeedLayoutVector(const xla::Shape & shape,std::vector<int> * layouts)28 Status PopulateInfeedLayoutVector(const xla::Shape& shape,
29 std::vector<int>* layouts) {
30 if (shape.IsTuple()) {
31 int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape);
32 for (int64 i = 0; i < tuple_elements; ++i) {
33 const xla::Shape& subshape =
34 xla::ShapeUtil::GetTupleElementShape(shape, i);
35 TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts));
36 }
37 } else if (xla::LayoutUtil::HasLayout(shape)) {
38 for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) {
39 layouts->push_back(dim);
40 }
41 } else {
42 layouts->insert(layouts->end(), shape.rank(), -1);
43 }
44 return Status::OK();
45 }
46
47 // Populate the output layout unless the minor_to_major array contains all -1
48 // value, in which case the layout is considered missing and the API returns
49 // false.
MakeLayout(absl::Span<const int64> minor_to_major,xla::Layout * layout)50 xla::StatusOr<bool> MakeLayout(absl::Span<const int64> minor_to_major,
51 xla::Layout* layout) {
52 if (std::all_of(minor_to_major.begin(), minor_to_major.end(),
53 [](int64 dim) { return dim == -1; })) {
54 return false;
55 }
56 std::vector<bool> dim_present(minor_to_major.size(), false);
57 for (auto dim : minor_to_major) {
58 const int minor_to_major_size = minor_to_major.size();
59 if (dim < 0 || dim >= minor_to_major_size) {
60 return errors::InvalidArgument("Layout dimension out of range: dim=", dim,
61 " rank=", minor_to_major.size());
62 }
63 if (dim_present[dim]) {
64 return errors::InvalidArgument("Repeated layout dimension: dim=", dim);
65 }
66 dim_present[dim] = true;
67 }
68 *layout = xla::LayoutUtil::MakeLayout(minor_to_major);
69 return true;
70 }
71
AssignLayout(absl::Span<const int64> minor_to_major,const std::function<xla::Layout (const xla::Shape &)> & layout_func,xla::Shape * shape)72 Status AssignLayout(
73 absl::Span<const int64> minor_to_major,
74 const std::function<xla::Layout(const xla::Shape&)>& layout_func,
75 xla::Shape* shape) {
76 xla::Layout layout;
77 TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout));
78 if (!has_layout && layout_func) {
79 layout = layout_func(*shape);
80 }
81 *shape->mutable_layout() = layout;
82 return Status::OK();
83 }
84
85 } // namespace
86
87 // Convert an XLA Shape into the equivalent TensorFlow shape.
XLAShapeToTensorShape(const xla::Shape & shape,TensorShape * tensor_shape)88 Status XLAShapeToTensorShape(const xla::Shape& shape,
89 TensorShape* tensor_shape) {
90 if (shape.IsTuple()) {
91 return errors::InvalidArgument("XLA shape ",
92 xla::ShapeUtil::HumanString(shape),
93 " cannot be converted to a TensorShape");
94 }
95 *tensor_shape = TensorShape();
96 for (int i = 0; i < shape.rank(); ++i) {
97 tensor_shape->AddDim(shape.dimensions(i));
98 }
99 return Status::OK();
100 }
101
102 // Convert a TensorShape into the equivalent XLA Shape proto.
TensorShapeToXLAShape(DataType dtype,const PartialTensorShape & tensor_shape,xla::Shape * shape)103 Status TensorShapeToXLAShape(DataType dtype,
104 const PartialTensorShape& tensor_shape,
105 xla::Shape* shape) {
106 xla::PrimitiveType type;
107 TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
108 *shape = TensorShapeToXLAShape(type, tensor_shape);
109 return Status::OK();
110 }
111
TensorShapeToXLAShape(xla::PrimitiveType type,const PartialTensorShape & tensor_shape)112 xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
113 const PartialTensorShape& tensor_shape) {
114 if (tensor_shape.unknown_rank()) {
115 // For unknown shape, create a rank 1 size 0 tensor.
116 return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
117 }
118 int rank = tensor_shape.dims();
119 std::vector<int64> dimensions(rank);
120 std::vector<bool> dynamic_dimensions(rank, false);
121 std::vector<int64> layout(rank);
122 for (int d = 0; d < rank; ++d) {
123 dimensions[d] = tensor_shape.dim_size(d);
124 if (dimensions[d] < 0) {
125 dynamic_dimensions[d] = true;
126 // TODO(b/177329258): Consider improving this/enabling MakeShapeWithLayout
127 // to work wuith dynamic shapes.
128 LOG(WARNING) << "Unable to convert TF shape with dynamic size to XLA "
129 "shape; returning unknown sentinel value";
130 return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
131 }
132 }
133 // XLA uses minor-to-major; Tensorflow uses major-to-minor.
134 std::iota(layout.rbegin(), layout.rend(), 0);
135 xla::Shape result =
136 xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
137
138 for (int64 d = 0; d < rank; ++d) {
139 result.set_dynamic_dimension(d, dynamic_dimensions[d]);
140 }
141 return result;
142 }
143
144 // Convert a TensorShape into the equivalent XLA Shape proto.
TensorShapeToXLAShape(DataType dtype,const TensorShape & tensor_shape,xla::Shape * shape)145 Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
146 xla::Shape* shape) {
147 xla::PrimitiveType type;
148 TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
149 *shape = TensorShapeToXLAShape(type, tensor_shape);
150 return Status::OK();
151 }
152
TensorShapeToXLAShape(xla::PrimitiveType type,const TensorShape & tensor_shape)153 xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
154 const TensorShape& tensor_shape) {
155 int rank = tensor_shape.dims();
156 std::vector<int64> dimensions(rank);
157 std::vector<int64> layout(rank);
158 for (int d = 0; d < rank; ++d) {
159 dimensions[d] = tensor_shape.dim_size(d);
160 }
161 // XLA uses minor-to-major; Tensorflow uses major-to-minor.
162 std::iota(layout.rbegin(), layout.rend(), 0);
163
164 return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
165 }
166
// Returns the flattened layout vector that PopulateInfeedLayoutVector
// produces for `shape`, or the error it reports.
xla::StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
  std::vector<int> flattened_layout;
  TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &flattened_layout));
  return flattened_layout;
}
172
GetShapeWithLayout(const xla::Shape & input_shape,absl::Span<const int64> minor_to_major,const std::function<xla::Layout (const xla::Shape &)> & layout_func,xla::Shape * output_shape)173 Status GetShapeWithLayout(
174 const xla::Shape& input_shape, absl::Span<const int64> minor_to_major,
175 const std::function<xla::Layout(const xla::Shape&)>& layout_func,
176 xla::Shape* output_shape) {
177 if (input_shape.IsTuple()) {
178 int64 tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape);
179 std::vector<xla::Shape> shapes;
180 shapes.reserve(tuple_elements);
181 size_t position = 0;
182 for (int64 i = 0; i < tuple_elements; ++i) {
183 const xla::Shape& shape =
184 xla::ShapeUtil::GetTupleElementShape(input_shape, i);
185 if (shape.IsTuple()) {
186 return errors::InvalidArgument(
187 "Nested tuples not supported: ",
188 xla::ShapeUtil::HumanString(input_shape));
189 }
190 int64 rank = shape.rank();
191 if (position + rank > minor_to_major.size()) {
192 return errors::InvalidArgument(
193 "Not enough layout attribute elements: position=", position,
194 " rank=", rank, " elements=", minor_to_major.size());
195 }
196 shapes.push_back(shape);
197 TF_RETURN_IF_ERROR(AssignLayout(
198 absl::Span<const int64>(minor_to_major).subspan(position, rank),
199 layout_func, &shapes.back()));
200 position += rank;
201
202 VLOG(4) << "Shape[" << i
203 << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back());
204 }
205 if (position != minor_to_major.size()) {
206 return errors::InvalidArgument(
207 "Too many elements passed in the layout attribute: position=",
208 position, " size=", minor_to_major.size());
209 }
210 *output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
211 } else {
212 int64 rank = input_shape.rank();
213 const int64 minor_to_major_size = minor_to_major.size();
214 if (rank != minor_to_major_size) {
215 return errors::InvalidArgument(
216 "Wrong number of layout attribute elements: rank=", rank,
217 " elements=", minor_to_major.size());
218 }
219 *output_shape = input_shape;
220 TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape));
221
222 VLOG(4) << "Shape[] = "
223 << xla::ShapeUtil::HumanStringWithLayout(*output_shape);
224 }
225 return Status::OK();
226 }
227
228 } // namespace tensorflow
229