/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/shape_util.h"

#include <numeric>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

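// Flattens the minor-to-major layout of every leaf shape in `shape` into
// `layouts`, in tuple order, appending -1 for each dimension of a leaf that
// has no layout. For example, a tuple (f32[2,3]{1,0}, s32[4]{0}) yields
// {1, 0, 0}.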
Status PopulateInfeedLayoutVector(const xla::Shape& shape,
                                  std::vector<int>* layouts) {
  if (shape.IsTuple()) {
    int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape);
    for (int64 i = 0; i < tuple_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetTupleElementShape(shape, i);
      TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts));
    }
  } else if (xla::LayoutUtil::HasLayout(shape)) {
    for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) {
      layouts->push_back(dim);
    }
  } else {
    layouts->insert(layouts->end(), shape.rank(), -1);
  }
  return Status::OK();
}

// Populate the output layout unless the minor_to_major array contains all -1
// values, in which case the layout is considered missing and the API returns
// false.
xla::StatusOr<bool> MakeLayout(absl::Span<const int64> minor_to_major,
                               xla::Layout* layout) {
  if (std::all_of(minor_to_major.begin(), minor_to_major.end(),
                  [](int64 dim) { return dim == -1; })) {
    return false;
  }
  std::vector<bool> dim_present(minor_to_major.size(), false);
  for (auto dim : minor_to_major) {
    const int minor_to_major_size = minor_to_major.size();
    if (dim < 0 || dim >= minor_to_major_size) {
      return errors::InvalidArgument("Layout dimension out of range: dim=", dim,
                                     " rank=", minor_to_major.size());
    }
    if (dim_present[dim]) {
      return errors::InvalidArgument("Repeated layout dimension: dim=", dim);
    }
    dim_present[dim] = true;
  }
  *layout = xla::LayoutUtil::MakeLayout(minor_to_major);
  return true;
}

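// Assigns a layout to `shape`: the one described by `minor_to_major` when it
// is present (i.e. not all -1), otherwise the one computed by `layout_func`,
// when provided.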
Status AssignLayout(
    absl::Span<const int64> minor_to_major,
    const std::function<xla::Layout(const xla::Shape&)>& layout_func,
    xla::Shape* shape) {
  xla::Layout layout;
  TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout));
  if (!has_layout && layout_func) {
    layout = layout_func(*shape);
  }
  *shape->mutable_layout() = layout;
  return Status::OK();
}

}  // namespace

// Convert an XLA Shape into the equivalent TensorFlow shape. Returns an
// error if `shape` is a tuple, which has no TensorShape equivalent.
Status XLAShapeToTensorShape(const xla::Shape& shape,
                             TensorShape* tensor_shape) {
  if (shape.IsTuple()) {
    return errors::InvalidArgument("XLA shape ",
                                   xla::ShapeUtil::HumanString(shape),
                                   " cannot be converted to a TensorShape");
  }
  *tensor_shape = TensorShape();
  for (int i = 0; i < shape.rank(); ++i) {
    tensor_shape->AddDim(shape.dimensions(i));
  }
  return Status::OK();
}

// Convert a PartialTensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype,
                             const PartialTensorShape& tensor_shape,
                             xla::Shape* shape) {
  xla::PrimitiveType type;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
  *shape = TensorShapeToXLAShape(type, tensor_shape);
  return Status::OK();
}

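// Convert a PartialTensorShape into the equivalent XLA Shape. Shapes with
// unknown rank, or with any dimension of unknown size, are mapped to a
// rank-1, size-0 sentinel shape.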
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
                                 const PartialTensorShape& tensor_shape) {
  if (tensor_shape.unknown_rank()) {
    // For unknown shape, create a rank 1 size 0 tensor.
    return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
  }
  int rank = tensor_shape.dims();
  std::vector<int64> dimensions(rank);
  std::vector<bool> dynamic_dimensions(rank, false);
  std::vector<int64> layout(rank);
  for (int d = 0; d < rank; ++d) {
    dimensions[d] = tensor_shape.dim_size(d);
    if (dimensions[d] < 0) {
      dynamic_dimensions[d] = true;
      // TODO(b/177329258): Consider improving this/enabling MakeShapeWithLayout
      // to work with dynamic shapes.
      LOG(WARNING) << "Unable to convert TF shape with dynamic size to XLA "
                      "shape; returning unknown sentinel value";
      return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
    }
  }
  // XLA uses minor-to-major; TensorFlow uses major-to-minor.
  std::iota(layout.rbegin(), layout.rend(), 0);
  xla::Shape result =
      xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);

  for (int64 d = 0; d < rank; ++d) {
    result.set_dynamic_dimension(d, dynamic_dimensions[d]);
  }
  return result;
}

// Convert a TensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
                             xla::Shape* shape) {
  xla::PrimitiveType type;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
  *shape = TensorShapeToXLAShape(type, tensor_shape);
  return Status::OK();
}

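// Convert a fully defined TensorShape into the equivalent XLA Shape, using
// the default major-to-minor ordering; e.g. a rank-2 TensorShape {2, 3} with
// F32 becomes f32[2,3]{1,0}.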
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
                                 const TensorShape& tensor_shape) {
  int rank = tensor_shape.dims();
  std::vector<int64> dimensions(rank);
  std::vector<int64> layout(rank);
  for (int d = 0; d < rank; ++d) {
    dimensions[d] = tensor_shape.dim_size(d);
  }
  // XLA uses minor-to-major; TensorFlow uses major-to-minor.
  std::iota(layout.rbegin(), layout.rend(), 0);

  return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
}

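// Returns the flattened minor-to-major layouts of all leaf shapes in
// `shape`, in tuple order, with -1 for each dimension whose layout is
// missing.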
xla::StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
  std::vector<int> layouts;
  TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts));
  return layouts;
}

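// Rewrites `input_shape` into `output_shape` with layouts taken from the
// flat `minor_to_major` attribute: each (non-nested) tuple element consumes
// a slice of `minor_to_major` equal to its rank, and slices of all -1 fall
// back to `layout_func`. Fails on nested tuples or a size mismatch between
// `minor_to_major` and the total rank.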
Status GetShapeWithLayout(
    const xla::Shape& input_shape, absl::Span<const int64> minor_to_major,
    const std::function<xla::Layout(const xla::Shape&)>& layout_func,
    xla::Shape* output_shape) {
  if (input_shape.IsTuple()) {
    int64 tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape);
    std::vector<xla::Shape> shapes;
    shapes.reserve(tuple_elements);
    size_t position = 0;
    for (int64 i = 0; i < tuple_elements; ++i) {
      const xla::Shape& shape =
          xla::ShapeUtil::GetTupleElementShape(input_shape, i);
      if (shape.IsTuple()) {
        return errors::InvalidArgument(
            "Nested tuples not supported: ",
            xla::ShapeUtil::HumanString(input_shape));
      }
      int64 rank = shape.rank();
      if (position + rank > minor_to_major.size()) {
        return errors::InvalidArgument(
            "Not enough layout attribute elements: position=", position,
            " rank=", rank, " elements=", minor_to_major.size());
      }
      shapes.push_back(shape);
      TF_RETURN_IF_ERROR(AssignLayout(
          absl::Span<const int64>(minor_to_major).subspan(position, rank),
          layout_func, &shapes.back()));
      position += rank;

      VLOG(4) << "Shape[" << i
              << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back());
    }
    if (position != minor_to_major.size()) {
      return errors::InvalidArgument(
          "Too many elements passed in the layout attribute: position=",
          position, " size=", minor_to_major.size());
    }
    *output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  } else {
    int64 rank = input_shape.rank();
    const int64 minor_to_major_size = minor_to_major.size();
    if (rank != minor_to_major_size) {
      return errors::InvalidArgument(
          "Wrong number of layout attribute elements: rank=", rank,
          " elements=", minor_to_major.size());
    }
    *output_shape = input_shape;
    TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape));

    VLOG(4) << "Shape[] = "
            << xla::ShapeUtil::HumanStringWithLayout(*output_shape);
  }
  return Status::OK();
}

}  // namespace tensorflow